Commit e0bf01d9 authored by Ananya Harsh Jha
Browse files

added hack for mismatched MNLI

parent 4c721c6b
...@@ -679,7 +679,6 @@ def main(): ...@@ -679,7 +679,6 @@ def main():
output_modes = { output_modes = {
"cola": "classification", "cola": "classification",
"mnli": "classification", "mnli": "classification",
"mnli-mm": "classification",
"mrpc": "classification", "mrpc": "classification",
"sst-2": "classification", "sst-2": "classification",
"sts-b": "regression", "sts-b": "regression",
...@@ -930,6 +929,8 @@ def main(): ...@@ -930,6 +929,8 @@ def main():
preds = preds[0] preds = preds[0]
if output_mode == "classification": if output_mode == "classification":
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
elif output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(task_name, preds, all_label_ids.numpy()) result = compute_metrics(task_name, preds, all_label_ids.numpy())
loss = tr_loss/nb_tr_steps if args.do_train else None loss = tr_loss/nb_tr_steps if args.do_train else None
...@@ -943,6 +944,69 @@ def main(): ...@@ -943,6 +944,69 @@ def main():
for key in sorted(result.keys()): for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key])) logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key]))) writer.write("%s = %s\n" % (key, str(result[key])))
# hack for MNLI-MM
if task_name == "mnli":
task_name = "mnli-mm"
processor = processors[task_name]()
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask, labels=None)
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if len(preds) == 0:
preds.append(logits.detach().cpu().numpy())
else:
preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = preds[0]
preds = np.argmax(preds, axis=1)
result = compute_metrics(task_name, preds, all_label_ids.numpy())
loss = tr_loss/nb_tr_steps if args.do_train else None
result['eval_loss'] = eval_loss
result['global_step'] = global_step
result['loss'] = loss
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment