Commit 07f46a64 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

iou

parent faa48887
...@@ -15,16 +15,16 @@ import torch, numpy as np, glob, math, torch.utils.data, scipy.ndimage, multipro ...@@ -15,16 +15,16 @@ import torch, numpy as np, glob, math, torch.utils.data, scipy.ndimage, multipro
dimension=3 dimension=3
full_scale=4096 #Input field size full_scale=4096 #Input field size
# VALID_CLAS_IDS have been mapped to the range {0,1,...,19} # Class IDs have been mapped to the range {0,1,...,19}
VALID_CLASS_IDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]) # NYU_CLASS_IDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
train,val=[],[] train,val=[],[]
for x in torch.utils.data.DataLoader( for x in torch.utils.data.DataLoader(
glob.glob('train/*.pth')[::10], glob.glob('train/*.pth'),
collate_fn=lambda x: torch.load(x[0]), num_workers=mp.cpu_count()): collate_fn=lambda x: torch.load(x[0]), num_workers=mp.cpu_count()):
train.append(x) train.append(x)
for x in torch.utils.data.DataLoader( for x in torch.utils.data.DataLoader(
glob.glob('val/*.pth')[::10], glob.glob('val/*.pth'),
collate_fn=lambda x: torch.load(x[0]), num_workers=mp.cpu_count()): collate_fn=lambda x: torch.load(x[0]), num_workers=mp.cpu_count()):
val.append(x) val.append(x)
print('Training examples:', len(train)) print('Training examples:', len(train))
......
...@@ -6,15 +6,15 @@ ...@@ -6,15 +6,15 @@
import torch, numpy as np import torch, numpy as np
#VALID_CLASS_IDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
#Classes relabelled {-100,0,1,...,19}. #Classes relabelled {-100,0,1,...,19}.
#Predictions will all be in the set {0,1,...,19} #Predictions will all be in the set {0,1,...,19}
#VALID_CLASS_IDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
VALID_CLASS_IDS = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
CLASS_LABELS = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture'] CLASS_LABELS = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']
UNKNOWN_ID = -100 UNKNOWN_ID = -100
N_CLASSES = len(CLASS_LABELS)
def confusion_matrix(pred_ids, gt_ids): def confusion_matrix(pred_ids, gt_ids):
assert pred_ids.shape == gt_ids.shape, (pred_ids.shape, gt_ids.shape) assert pred_ids.shape == gt_ids.shape, (pred_ids.shape, gt_ids.shape)
...@@ -24,11 +24,10 @@ def confusion_matrix(pred_ids, gt_ids): ...@@ -24,11 +24,10 @@ def confusion_matrix(pred_ids, gt_ids):
def get_iou(label_id, confusion): def get_iou(label_id, confusion):
# true positives # true positives
tp = np.longlong(confusion[label_id, label_id]) tp = np.longlong(confusion[label_id, label_id])
# false negatives
fn = np.longlong(confusion[label_id, :].sum()) - tp
# false positives # false positives
not_ignored = [l for l in VALID_CLASS_IDS if not l == label_id] fp = np.longlong(confusion[label_id, :].sum()) - tp
fp = np.longlong(confusion[not_ignored, label_id].sum()) # false negatives
fn = np.longlong(confusion[:, label_id].sum())
denom = (tp + fp + fn) denom = (tp + fp + fn)
if denom == 0: if denom == 0:
...@@ -40,15 +39,14 @@ def evaluate(pred_ids,gt_ids): ...@@ -40,15 +39,14 @@ def evaluate(pred_ids,gt_ids):
confusion=confusion_matrix(pred_ids,gt_ids) confusion=confusion_matrix(pred_ids,gt_ids)
class_ious = {} class_ious = {}
mean_iou = 0 mean_iou = 0
for i in range(len(VALID_CLASS_IDS)): for i in range(N_CLASSES):
label_name = CLASS_LABELS[i] label_name = CLASS_LABELS[i]
label_id = VALID_CLASS_IDS[i] class_ious[label_name] = get_iou(i, confusion)
class_ious[label_name] = get_iou(label_id, confusion)
mean_iou+=class_ious[label_name][0]/20 mean_iou+=class_ious[label_name][0]/20
print('classes IoU') print('classes IoU')
print('----------------------------') print('----------------------------')
for i in range(len(VALID_CLASS_IDS)): for i in range(N_CLASSES):
label_name = CLASS_LABELS[i] label_name = CLASS_LABELS[i]
print('{0:<14s}: {1:>5.3f} ({2:>6d}/{3:<6d})'.format(label_name, class_ious[label_name][0], class_ious[label_name][1], class_ious[label_name][2])) print('{0:<14s}: {1:>5.3f} ({2:>6d}/{3:<6d})'.format(label_name, class_ious[label_name][0], class_ious[label_name][1], class_ious[label_name][2]))
print('mean IOU', mean_iou) print('mean IOU', mean_iou)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment