Commit b54de837 authored by VictorSanh's avatar VictorSanh
Browse files

Quick fix on eval accuracy

parent 1d53f9cb
...@@ -548,6 +548,7 @@ def main(): ...@@ -548,6 +548,7 @@ def main():
model.eval() model.eval()
eval_loss = 0 eval_loss = 0
eval_accuracy = 0 eval_accuracy = 0
nb_eval_examples = 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
...@@ -562,9 +563,11 @@ def main(): ...@@ -562,9 +563,11 @@ def main():
eval_loss += tmp_eval_loss.item() eval_loss += tmp_eval_loss.item()
eval_accuracy += tmp_eval_accuracy eval_accuracy += tmp_eval_accuracy
nb_eval_examples += input_ids.size(0)
eval_loss = eval_loss / len(eval_dataloader) eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)
eval_accuracy = eval_accuracy / len(eval_dataloader) eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader)
result = {'eval_loss': eval_loss, result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy, 'eval_accuracy': eval_accuracy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment