Commit 11274575 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Variable(..) before softmax in classificationTrainValidate

parent 15fd91a0
......@@ -35,7 +35,7 @@ def updateStats(stats, output, target, loss):
# Top-5 score
l = min(5, correct.size(1))
stats['top5'] += correct.narrow(1, 0, l).sum()
stats['confusion matrix'].index_add_(0,target,F.softmax(output).data)
stats['confusion matrix'].index_add_(0,target,F.softmax(Variable(output),1).data)
def ClassificationTrainValidate(model, dataset, p):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment