Commit de6640be authored by Neel Kant's avatar Neel Kant
Browse files

Made topk accuracy reporting optional

parent 8d7f508a
...@@ -414,5 +414,9 @@ def _add_realm_args(parser): ...@@ -414,5 +414,9 @@ def _add_realm_args(parser):
group.add_argument('--ict-one-sent', action='store_true', group.add_argument('--ict-one-sent', action='store_true',
help='Whether to use one sentence documents in ICT') help='Whether to use one sentence documents in ICT')
# training
group.add_argument('--report-topk-accuracies', nargs='+', default=[],
help="Which top-k accuracies to report (e.g. '1 5 20')")
return parser return parser
...@@ -116,19 +116,16 @@ def forward_step(data_iterator, model): ...@@ -116,19 +116,16 @@ def forward_step(data_iterator, model):
softmaxed = F.softmax(retrieval_scores, dim=1) softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True) sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
def topk_acc(k): def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size]) return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda()) retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
reduced_losses = reduce_losses([retrieval_loss, *top_accs]) reduced_losses = reduce_losses([retrieval_loss, *topk_accs])
stats_dict = {
'retrieval loss': reduced_losses[0], # create stats_dict with retrieval loss and all specified top-k accuracies
'top1_acc': reduced_losses[1], topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
'top8_acc': reduced_losses[2], stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
'top20_acc': reduced_losses[3],
'top100_acc': reduced_losses[4],
}
return retrieval_loss, stats_dict return retrieval_loss, stats_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment