Commit 8210ee32 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

confusion matrix

parent 254109fd
......@@ -13,20 +13,29 @@ import sparseconvnet as s
import time
import os
import math
import numpy as np
import PIL
def updateStats(stats, output, target, loss):
batchSize = output.size(0)
nClasses= output.size(1)
if not stats:
stats['top1'] = 0
stats['top5'] = 0
stats['n'] = 0
stats['nll'] = 0
stats['confusion matrix'] = output.new().resize_(nClasses,nClasses).zero_()
stats['n'] = stats['n'] + batchSize
stats['nll'] = stats['nll'] + loss * batchSize
_, predictions = output.float().sort(1, True)
correct = predictions.eq(
target.long()[:,None].expand_as(output))
target[:,None].expand_as(output))
# Top-1 score
stats['top1'] += correct.narrow(1, 0, 1).sum()
# Top-5 score
l = min(5, correct.size(1))
stats['top5'] += correct.narrow(1, 0, l).sum()
stats['confusion matrix'].index_add_(0,target,F.softmax(output).data)
def ClassificationTrainValidate(model, dataset, p):
......@@ -66,7 +75,7 @@ def ClassificationTrainValidate(model, dataset, p):
print('#parameters', sum([x.nelement() for x in model.parameters()]))
for epoch in range(p['epoch'], p['n_epochs'] + 1):
model.train()
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
stats = {}
for param_group in optimizer.param_groups:
param_group['lr'] = p['initial_lr'] * \
math.exp((1 - epoch) * p['lr_decay'])
......@@ -93,7 +102,10 @@ def ClassificationTrainValidate(model, dataset, p):
stats['n']), stats['nll'] /
stats['n'], time.time() -
start))
cm=stats['confusion matrix'].cpu().numpy()
np.savetxt('train confusion matrix.csv',cm,delimiter=',')
cm*=255/(cm.sum(1,keepdims=True)+1e-9)
PIL.Image.fromarray(cm.astype('uint8'),mode='L').save('train confusion matrix.png')
if p['check_point']:
torch.save(epoch, 'epoch.pth')
torch.save(model.state_dict(),'model.pth')
......@@ -103,7 +115,7 @@ def ClassificationTrainValidate(model, dataset, p):
s.forward_pass_hidden_states = 0
start = time.time()
if p['test_reps'] ==1:
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
stats = {}
for batch in dataset['val']():
if p['use_gpu']:
batch['input']=batch['input'].cuda()
......@@ -145,7 +157,7 @@ def ClassificationTrainValidate(model, dataset, p):
else:
predictions.index_add_(0,idxs,pr)
loss = criterion(predictions/rep, targets)
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
stats = {}
updateStats(stats, predictions, targets, loss.data[0])
print(epoch, 'test rep ', rep,
': top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %(
......@@ -156,3 +168,7 @@ def ClassificationTrainValidate(model, dataset, p):
'%.3e MultiplyAdds/sample %.3e HiddenStates/sample' % (
s.forward_pass_multiplyAdd_count / stats['n'],
s.forward_pass_hidden_states / stats['n']))
cm=stats['confusion matrix'].cpu().numpy()
np.savetxt('test confusion matrix.csv',cm,delimiter=',')
cm*=255/(cm.sum(1,keepdims=True)+1e-9)
PIL.Image.fromarray(cm.astype('uint8'),mode='L').save('test confusion matrix.png')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment