Commit a0d33d69 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

confusion matrix

parent 8210ee32
...@@ -14,7 +14,7 @@ import time ...@@ -14,7 +14,7 @@ import time
import os import os
import math import math
import numpy as np import numpy as np
import PIL from PIL import Image
def updateStats(stats, output, target, loss): def updateStats(stats, output, target, loss):
batchSize = output.size(0) batchSize = output.size(0)
...@@ -105,7 +105,7 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -105,7 +105,7 @@ def ClassificationTrainValidate(model, dataset, p):
cm=stats['confusion matrix'].cpu().numpy() cm=stats['confusion matrix'].cpu().numpy()
np.savetxt('train confusion matrix.csv',cm,delimiter=',') np.savetxt('train confusion matrix.csv',cm,delimiter=',')
cm*=255/(cm.sum(1,keepdims=True)+1e-9) cm*=255/(cm.sum(1,keepdims=True)+1e-9)
PIL.Image.fromarray(cm.astype('uint8'),mode='L').save('train confusion matrix.png') Image.fromarray(cm.astype('uint8'),mode='L').save('train confusion matrix.png')
if p['check_point']: if p['check_point']:
torch.save(epoch, 'epoch.pth') torch.save(epoch, 'epoch.pth')
torch.save(model.state_dict(),'model.pth') torch.save(model.state_dict(),'model.pth')
...@@ -171,4 +171,4 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -171,4 +171,4 @@ def ClassificationTrainValidate(model, dataset, p):
cm=stats['confusion matrix'].cpu().numpy() cm=stats['confusion matrix'].cpu().numpy()
np.savetxt('test confusion matrix.csv',cm,delimiter=',') np.savetxt('test confusion matrix.csv',cm,delimiter=',')
cm*=255/(cm.sum(1,keepdims=True)+1e-9) cm*=255/(cm.sum(1,keepdims=True)+1e-9)
PIL.Image.fromarray(cm.astype('uint8'),mode='L').save('test confusion matrix.png') Image.fromarray(cm.astype('uint8'),mode='L').save('test confusion matrix.png')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment