Commit 254109fd authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

Enable multi testing in ClassificationTrainValidate

parent 6722cac3
...@@ -21,7 +21,7 @@ def updateStats(stats, output, target, loss): ...@@ -21,7 +21,7 @@ def updateStats(stats, output, target, loss):
stats['nll'] = stats['nll'] + loss * batchSize stats['nll'] = stats['nll'] + loss * batchSize
_, predictions = output.float().sort(1, True) _, predictions = output.float().sort(1, True)
correct = predictions.eq( correct = predictions.eq(
target.long().view(batchSize, 1).expand_as(output)) target.long()[:,None].expand_as(output))
# Top-1 score # Top-1 score
stats['top1'] += correct.narrow(1, 0, 1).sum() stats['top1'] += correct.narrow(1, 0, 1).sum()
# Top-5 score # Top-5 score
...@@ -132,20 +132,21 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -132,20 +132,21 @@ def ClassificationTrainValidate(model, dataset, p):
batch['target'] = batch['target'].cuda() batch['target'] = batch['target'].cuda()
batch['idx'] = batch['idx'].cuda() batch['idx'] = batch['idx'].cuda()
batch['input'].to_variable() batch['input'].to_variable()
pr.append( model(batch['input']).data ) output = model(batch['input'])
pr.append( output.data )
ta.append( batch['target'] ) ta.append( batch['target'] )
idxs.append( batch['idx'] ) idxs.append( batch['idx'] )
pr=torch.cat(pr,0) pr=torch.cat(pr,0)
ta=torch.cat(ta,0) ta=torch.cat(ta,0)
idxs=torch.cat(idxs,0) idxs=torch.cat(idxs,0)
if rep==1: if rep==1:
target=pr.index_select(0,idxs) predictions=pr.new().resize_as_(pr).zero_().index_add_(0,idxs,pr)
ta=ta.index_select(0,idxs) targets=ta.new().resize_as_(ta).zero_().index_add_(0,idxs,ta)
else: else:
target.index_add_(0,idxs,pr) predictions.index_add_(0,idxs,pr)
loss = criterion(pr, ta) loss = criterion(predictions/rep, targets)
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0} stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
updateStats(stats, pr, ta, loss.data[0]) updateStats(stats, predictions, targets, loss.data[0])
print(epoch, 'test rep ', rep, print(epoch, 'test rep ', rep,
': top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %( ': top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %(
100 * (1 - 1.0 * stats['top1'] / stats['n']), 100 * (1 - 1.0 * stats['top1'] / stats['n']),
......
...@@ -253,7 +253,7 @@ pip install git+https://github.com/pytorch/tnt.git@master ...@@ -253,7 +253,7 @@ pip install git+https://github.com/pytorch/tnt.git@master
2. [Spatially-sparse convolutional neural networks, 2014](http://arxiv.org/abs/1409.6070) SparseConvNets for Chinese handwriting recognition 2. [Spatially-sparse convolutional neural networks, 2014](http://arxiv.org/abs/1409.6070) SparseConvNets for Chinese handwriting recognition
3. [Fractional max-pooling, 2014](http://arxiv.org/abs/1412.6071) A SparseConvNet with fractional max-pooling achieves an error rate of 3.47% for CIFAR-10. 3. [Fractional max-pooling, 2014](http://arxiv.org/abs/1412.6071) A SparseConvNet with fractional max-pooling achieves an error rate of 3.47% for CIFAR-10.
4. [Sparse 3D convolutional neural networks, BMVC 2015](http://arxiv.org/abs/1505.02890) SparseConvNets for 3D object recognition and (2+1)D video action recognition. 4. [Sparse 3D convolutional neural networks, BMVC 2015](http://arxiv.org/abs/1505.02890) SparseConvNets for 3D object recognition and (2+1)D video action recognition.
5. [Kaggle plankton recognition competition, 2015](https://www.kaggle.com/c/datasciencebowl) Third place. The competition solution is being adapted for research purposes. 5. [Kaggle plankton recognition competition, 2015](https://www.kaggle.com/c/datasciencebowl) Third place. The competition solution is being adapted for research purposes in [EcoTaxa](http://ecotaxa.obs-vlfr.fr/).
6. [Kaggle Diabetic Retinopathy Detection, 2015](https://www.kaggle.com/c/diabetic-retinopathy-detection/) First place in the Kaggle Diabetic Retinopathy Detection competition. 6. [Kaggle Diabetic Retinopathy Detection, 2015](https://www.kaggle.com/c/diabetic-retinopathy-detection/) First place in the Kaggle Diabetic Retinopathy Detection competition.
7. [Submanifold Sparse Convolutional Networks, 2017](https://arxiv.org/abs/1706.01307) Introduces deep 'submanifold' SparseConvNets. 7. [Submanifold Sparse Convolutional Networks, 2017](https://arxiv.org/abs/1706.01307) Introduces deep 'submanifold' SparseConvNets.
8. [Workshop on Learning to See from 3D Data, 2017](https://shapenet.cs.stanford.edu/iccv17workshop/) First place in the [semantic segmentation](https://shapenet.cs.stanford.edu/iccv17/) competition. [Report](https://arxiv.org/pdf/1710.06104) 8. [Workshop on Learning to See from 3D Data, 2017](https://shapenet.cs.stanford.edu/iccv17workshop/) First place in the [semantic segmentation](https://shapenet.cs.stanford.edu/iccv17/) competition. [Report](https://arxiv.org/pdf/1710.06104)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment