Commit d609ba24 authored by thomwolf

resolving merge conflicts

parent 64ce9009
@@ -46,7 +46,10 @@ logger = logging.getLogger(__name__)
 
 
 class SquadExample(object):
-    """A single training/test example for the Squad dataset."""
+    """
+    A single training/test example for the Squad dataset.
+    For examples without an answer, the start and end position are -1.
+    """
 
     def __init__(self,
                  qas_id,
@@ -54,13 +57,15 @@ class SquadExample(object):
                  doc_tokens,
                  orig_answer_text=None,
                  start_position=None,
-                 end_position=None):
+                 end_position=None,
+                 is_impossible=None):
         self.qas_id = qas_id
         self.question_text = question_text
         self.doc_tokens = doc_tokens
         self.orig_answer_text = orig_answer_text
         self.start_position = start_position
         self.end_position = end_position
+        self.is_impossible = is_impossible
 
     def __str__(self):
         return self.__repr__()
@@ -75,6 +80,8 @@ class SquadExample(object):
             s += ", start_position: %d" % (self.start_position)
         if self.start_position:
             s += ", end_position: %d" % (self.end_position)
+        if self.start_position:
+            s += ", is_impossible: %r" % (self.is_impossible)
         return s
 
 
@@ -92,7 +99,8 @@ class InputFeatures(object):
                  input_mask,
                  segment_ids,
                  start_position=None,
-                 end_position=None):
+                 end_position=None,
+                 is_impossible=None):
         self.unique_id = unique_id
         self.example_index = example_index
         self.doc_span_index = doc_span_index
@@ -104,9 +112,10 @@ class InputFeatures(object):
         self.segment_ids = segment_ids
         self.start_position = start_position
         self.end_position = end_position
+        self.is_impossible = is_impossible
 
 
-def read_squad_examples(input_file, is_training):
+def read_squad_examples(input_file, is_training, version_2_with_negative):
     """Read a SQuAD json file into a list of SquadExample."""
     with open(input_file, "r", encoding='utf-8') as reader:
         input_data = json.load(reader)["data"]
@@ -140,10 +149,14 @@ def read_squad_examples(input_file, is_training):
                 start_position = None
                 end_position = None
                 orig_answer_text = None
+                is_impossible = False
                 if is_training:
-                    if len(qa["answers"]) != 1:
+                    if version_2_with_negative:
+                        is_impossible = qa["is_impossible"]
+                    if (len(qa["answers"]) != 1) and (not is_impossible):
                         raise ValueError(
                             "For training, each question should have exactly 1 answer.")
+                    if not is_impossible:
                         answer = qa["answers"][0]
                         orig_answer_text = answer["text"]
                         answer_offset = answer["answer_start"]
@@ -163,6 +176,10 @@ def read_squad_examples(input_file, is_training):
                             logger.warning("Could not find answer: '%s' vs. '%s'",
                                            actual_text, cleaned_answer_text)
                             continue
+                    else:
+                        start_position = -1
+                        end_position = -1
+                        orig_answer_text = ""
 
                 example = SquadExample(
                     qas_id=qas_id,
@@ -170,7 +187,8 @@ def read_squad_examples(input_file, is_training):
                     doc_tokens=doc_tokens,
                     orig_answer_text=orig_answer_text,
                     start_position=start_position,
-                    end_position=end_position)
+                    end_position=end_position,
+                    is_impossible=is_impossible)
                 examples.append(example)
     return examples
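
Note: the snippet below is an illustrative sketch, not part of the commit. It traces how an unanswerable SQuAD 2.0 entry flows through the branch added to read_squad_examples above: is_impossible is taken straight from the JSON, and the answer text and positions fall back to "" and -1. The qa dict is made up.

qa = {"id": "q-001", "question": "Who wrote it?",   # made-up SQuAD 2.0 "qas" entry
      "answers": [], "is_impossible": True}
version_2_with_negative = True

start_position, end_position, orig_answer_text = None, None, None
is_impossible = False
if version_2_with_negative:
    is_impossible = qa["is_impossible"]
if not is_impossible:
    answer = qa["answers"][0]
    orig_answer_text = answer["text"]
else:
    # unanswerable questions carry an empty answer string and -1 positions
    start_position, end_position, orig_answer_text = -1, -1, ""

print(is_impossible, start_position, end_position, repr(orig_answer_text))  # True -1 -1 ''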
@@ -200,7 +218,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 
         tok_start_position = None
         tok_end_position = None
-        if is_training:
+        if is_training and example.is_impossible:
+            tok_start_position = -1
+            tok_end_position = -1
+        if is_training and not example.is_impossible:
             tok_start_position = orig_to_tok_index[example.start_position]
             if example.end_position < len(example.doc_tokens) - 1:
                 tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
@@ -272,20 +293,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 
             start_position = None
             end_position = None
-            if is_training:
+            if is_training and not example.is_impossible:
                 # For training, if our document chunk does not contain an annotation
                 # we throw it out, since there is nothing to predict.
                 doc_start = doc_span.start
                 doc_end = doc_span.start + doc_span.length - 1
-                if (example.start_position < doc_start or
-                        example.end_position < doc_start or
-                        example.start_position > doc_end or example.end_position > doc_end):
-                    continue
-
-                doc_offset = len(query_tokens) + 2
-                start_position = tok_start_position - doc_start + doc_offset
-                end_position = tok_end_position - doc_start + doc_offset
+                out_of_span = False
+                if not (tok_start_position >= doc_start and
+                        tok_end_position <= doc_end):
+                    out_of_span = True
+                if out_of_span:
+                    start_position = 0
+                    end_position = 0
+                else:
+                    doc_offset = len(query_tokens) + 2
+                    start_position = tok_start_position - doc_start + doc_offset
+                    end_position = tok_end_position - doc_start + doc_offset
+            if is_training and example.is_impossible:
+                start_position = 0
+                end_position = 0
 
             if example_index < 20:
                 logger.info("*** Example ***")
                 logger.info("unique_id: %s" % (unique_id))
@@ -302,7 +328,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     "input_mask: %s" % " ".join([str(x) for x in input_mask]))
                 logger.info(
                     "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
-                if is_training:
+                if is_training and example.is_impossible:
+                    logger.info("impossible example")
+                if is_training and not example.is_impossible:
                     answer_text = " ".join(tokens[start_position:(end_position + 1)])
                     logger.info("start_position: %d" % (start_position))
                     logger.info("end_position: %d" % (end_position))
@@ -321,7 +349,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     input_mask=input_mask,
                     segment_ids=segment_ids,
                     start_position=start_position,
-                    end_position=end_position))
+                    end_position=end_position,
+                    is_impossible=example.is_impossible))
             unique_id += 1
 
     return features
 
@@ -401,15 +430,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
 
     return cur_span_index == best_span_index
 
 
 RawResult = collections.namedtuple("RawResult",
                                    ["unique_id", "start_logits", "end_logits"])
 
 
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
-                      output_nbest_file, verbose_logging):
-    """Write final predictions to the json file."""
+                      output_nbest_file, output_null_log_odds_file, verbose_logging,
+                      version_2_with_negative, null_score_diff_threshold):
+    """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -427,15 +456,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
 
     all_predictions = collections.OrderedDict()
     all_nbest_json = collections.OrderedDict()
+    scores_diff_json = collections.OrderedDict()
 
     for (example_index, example) in enumerate(all_examples):
         features = example_index_to_features[example_index]
 
         prelim_predictions = []
+        # keep track of the minimum score of null start+end of position 0
+        score_null = 1000000  # large and positive
+        min_null_feature_index = 0  # the paragraph slice with min null score
+        null_start_logit = 0  # the start logit at the slice with min null score
+        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
             result = unique_id_to_result[feature.unique_id]
             start_indexes = _get_best_indexes(result.start_logits, n_best_size)
             end_indexes = _get_best_indexes(result.end_logits, n_best_size)
+            # if we could have irrelevant answers, get the min score of irrelevant
+            if version_2_with_negative:
+                feature_null_score = result.start_logits[0] + result.end_logits[0]
+                if feature_null_score < score_null:
+                    score_null = feature_null_score
+                    min_null_feature_index = feature_index
+                    null_start_logit = result.start_logits[0]
+                    null_end_logit = result.end_logits[0]
             for start_index in start_indexes:
                 for end_index in end_indexes:
                     # We could hypothetically create invalid predictions, e.g., predict
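
Note: a minimal sketch (not part of the commit) of the null-score bookkeeping introduced above. The score of the empty answer is the start+end logit of position 0 ([CLS]), and the minimum over all chunks of a document is kept. The logits below are invented.

start_logits_per_chunk = [[-1.2, 3.0, 0.5], [0.4, 2.1, -0.3]]   # invented values
end_logits_per_chunk = [[-0.8, 0.2, 2.9], [0.9, -0.1, 1.7]]

score_null = 1000000  # large and positive, as in the hunk above
min_null_feature_index = 0
for feature_index, (start_logits, end_logits) in enumerate(
        zip(start_logits_per_chunk, end_logits_per_chunk)):
    feature_null_score = start_logits[0] + end_logits[0]
    if feature_null_score < score_null:
        score_null = feature_null_score
        min_null_feature_index = feature_index

print(score_null, min_null_feature_index)  # -2.0 0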
@@ -463,7 +506,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                             end_index=end_index,
                             start_logit=result.start_logits[start_index],
                             end_logit=result.end_logits[end_index]))
+        if version_2_with_negative:
+            prelim_predictions.append(
+                _PrelimPrediction(
+                    feature_index=min_null_feature_index,
+                    start_index=0,
+                    end_index=0,
+                    start_logit=null_start_logit,
+                    end_logit=null_end_logit))
         prelim_predictions = sorted(
             prelim_predictions,
             key=lambda x: (x.start_logit + x.end_logit),
@@ -478,7 +528,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
             if len(nbest) >= n_best_size:
                 break
             feature = features[pred.feature_index]
-            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
-            orig_doc_start = feature.token_to_orig_map[pred.start_index]
-            orig_doc_end = feature.token_to_orig_map[pred.end_index]
+            if pred.start_index > 0:  # this is a non-null prediction
+                tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+                orig_doc_start = feature.token_to_orig_map[pred.start_index]
+                orig_doc_end = feature.token_to_orig_map[pred.end_index]
@@ -499,12 +549,23 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                     continue
                 seen_predictions[final_text] = True
+            else:
+                final_text = ""
+                seen_predictions[final_text] = True
 
             nbest.append(
                 _NbestPrediction(
                     text=final_text,
                     start_logit=pred.start_logit,
                     end_logit=pred.end_logit))
+        # if we didn't include the empty option in the n-best, include it
+        if version_2_with_negative:
+            if "" not in seen_predictions:
+                nbest.append(
+                    _NbestPrediction(
+                        text="",
+                        start_logit=null_start_logit,
+                        end_logit=null_end_logit))
 
         # In very rare edge cases we could have no valid predictions. So we
         # just create a nonce prediction in this case to avoid failure.
         if not nbest:
@@ -514,8 +575,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
 
         assert len(nbest) >= 1
 
         total_scores = []
+        best_non_null_entry = None
         for entry in nbest:
             total_scores.append(entry.start_logit + entry.end_logit)
+            if not best_non_null_entry:
+                if entry.text:
+                    best_non_null_entry = entry
 
         probs = _compute_softmax(total_scores)
@@ -530,7 +595,17 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
 
         assert len(nbest_json) >= 1
 
-        all_predictions[example.qas_id] = nbest_json[0]["text"]
+        if not version_2_with_negative:
+            all_predictions[example.qas_id] = nbest_json[0]["text"]
+        else:
+            # predict "" iff the null score - the score of best non-null > threshold
+            score_diff = score_null - best_non_null_entry.start_logit - (
+                best_non_null_entry.end_logit)
+            scores_diff_json[example.qas_id] = score_diff
+            if score_diff > null_score_diff_threshold:
+                all_predictions[example.qas_id] = ""
+            else:
+                all_predictions[example.qas_id] = best_non_null_entry.text
         all_nbest_json[example.qas_id] = nbest_json
 
     with open(output_prediction_file, "w") as writer:
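
Note: a short sketch (not part of the commit) of the decision rule in the hunk above. The question is predicted unanswerable iff the null score minus the best non-null span score exceeds --null_score_diff_threshold, so a higher threshold predicts "" less often. The numbers are invented.

score_null = 1.5                              # minimal [CLS] start+end logit over chunks
best_start_logit, best_end_logit = 2.0, 1.0   # logits of the best non-null span
null_score_diff_threshold = 0.0               # value of the new flag

score_diff = score_null - best_start_logit - best_end_logit
prediction = "" if score_diff > null_score_diff_threshold else "best non-null span text"
print(score_diff, repr(prediction))           # -1.5 'best non-null span text'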
@@ -539,6 +614,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
 
     with open(output_nbest_file, "w") as writer:
         writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
 
+    if version_2_with_negative:
+        with open(output_null_log_odds_file, "w") as writer:
+            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
 
 def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
     """Project the tokenized prediction back to the original text."""
@@ -701,7 +780,7 @@ def main():
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
+                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                              "of training.")
     parser.add_argument("--n_best_size", default=20, type=int,
                         help="The total number of n-best predictions to generate in the nbest_predictions.json "
...@@ -738,7 +817,12 @@ def main(): ...@@ -738,7 +817,12 @@ def main():
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n" "0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n") "Positive power of 2: static loss scaling value.\n")
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument('--null_score_diff_threshold',
type=float, default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.")
args = parser.parse_args() args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda: if args.local_rank == -1 or args.no_cuda:
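
Note: a standalone sketch (not part of the commit) showing how the two new flags parse. The argument list handed to parse_args is made up; in the real script the values come from the command line.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--version_2_with_negative', action='store_true',
                    help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                    help="If null_score - best_non_null is greater than the threshold predict null.")

args = parser.parse_args(['--version_2_with_negative',
                          '--null_score_diff_threshold', '2.0'])
print(args.version_2_with_negative, args.null_score_diff_threshold)  # True 2.0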
@@ -787,9 +871,9 @@ def main():
     num_train_optimization_steps = None
     if args.do_train:
         train_examples = read_squad_examples(
-            input_file=args.train_file, is_training=True)
+            input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
         num_train_optimization_steps = int(
-            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
+            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
         if args.local_rank != -1:
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
@@ -825,7 +909,7 @@ def main():
     if args.fp16:
         try:
-            from apex.optimizers import FP16_Optimizer
+            from apex.optimizer import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
@@ -901,7 +985,7 @@ def main():
                 if (step + 1) % args.gradient_accumulation_steps == 0:
                     if args.fp16:
                         # modify learning rate with special warm up BERT uses
-                        # if args.fp16 is False, BertAdam is used that handles this automatically
+                        # if args.fp16 is False, BertAdam is used and handles this automatically
                         lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                         for param_group in optimizer.param_groups:
                             param_group['lr'] = lr_this_step
@@ -914,7 +998,6 @@ def main():
     output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
         torch.save(model_to_save.state_dict(), output_model_file)
-
     # Load a trained model that you have fine-tuned
     model_state_dict = torch.load(output_model_file)
     model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
@@ -925,7 +1008,7 @@ def main():
     if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = read_squad_examples(
-            input_file=args.predict_file, is_training=False)
+            input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
         eval_features = convert_examples_to_features(
             examples=eval_examples,
             tokenizer=tokenizer,
@@ -969,10 +1052,12 @@ def main():
                                              end_logits=end_logits))
 
         output_prediction_file = os.path.join(args.output_dir, "predictions.json")
         output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
+        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
         write_predictions(eval_examples, eval_features, all_results,
                           args.n_best_size, args.max_answer_length,
                           args.do_lower_case, output_prediction_file,
-                          output_nbest_file, args.verbose_logging)
+                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
+                          args.version_2_with_negative, args.null_score_diff_threshold)
 
 
 if __name__ == "__main__":