"examples/vscode:/vscode.git/clone" did not exist on "dcd1428dcc769d175734260764eeadf160593335"
Commit 6722cac3 authored by Benjamin Graham's avatar Benjamin Graham
Browse files

Enable multi testing in ClassificationTrainValidate

parent d8b64558
...@@ -30,7 +30,7 @@ def updateStats(stats, output, target, loss): ...@@ -30,7 +30,7 @@ def updateStats(stats, output, target, loss):
def ClassificationTrainValidate(model, dataset, p): def ClassificationTrainValidate(model, dataset, p):
criterion = nn.CrossEntropyLoss() criterion = F.cross_entropy
if 'n_epochs' not in p: if 'n_epochs' not in p:
p['n_epochs'] = 100 p['n_epochs'] = 100
if 'initial_lr' not in p: if 'initial_lr' not in p:
...@@ -47,7 +47,8 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -47,7 +47,8 @@ def ClassificationTrainValidate(model, dataset, p):
p['use_gpu'] = torch.cuda.is_available() p['use_gpu'] = torch.cuda.is_available()
if p['use_gpu']: if p['use_gpu']:
model.cuda() model.cuda()
criterion.cuda() if 'test_reps' not in p:
p['test_reps'] = 1
optimizer = optim.SGD(model.parameters(), optimizer = optim.SGD(model.parameters(),
lr=p['initial_lr'], lr=p['initial_lr'],
momentum = p['momentum'], momentum = p['momentum'],
...@@ -100,8 +101,9 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -100,8 +101,9 @@ def ClassificationTrainValidate(model, dataset, p):
model.eval() model.eval()
s.forward_pass_multiplyAdd_count = 0 s.forward_pass_multiplyAdd_count = 0
s.forward_pass_hidden_states = 0 s.forward_pass_hidden_states = 0
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
start = time.time() start = time.time()
if p['test_reps'] ==1:
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
for batch in dataset['val'](): for batch in dataset['val']():
if p['use_gpu']: if p['use_gpu']:
batch['input']=batch['input'].cuda() batch['input']=batch['input'].cuda()
...@@ -111,19 +113,45 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -111,19 +113,45 @@ def ClassificationTrainValidate(model, dataset, p):
output = model(batch['input']) output = model(batch['input'])
loss = criterion(output, batch['target']) loss = criterion(output, batch['target'])
updateStats(stats, output.data, batch['target'].data, loss.data[0]) updateStats(stats, output.data, batch['target'].data, loss.data[0])
[removed]
            print(epoch, 'test: top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %
                  (100 * (1 - 1.0 * stats['top1'] / stats['n']),
                   100 * (1 - 1.0 * stats['top5'] / stats['n']),
                   stats['nll'] / stats['n'], time.time() - start))
            print('%.3e MultiplyAdds/sample %.3e HiddenStates/sample' %
                  (s.forward_pass_multiplyAdd_count / stats['n'],
                   s.forward_pass_hidden_states / stats['n']))
[added]
            print(epoch, 'test: top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %(
                100 * (1 - 1.0 * stats['top1'] / stats['n']),
                100 * (1 - 1.0 * stats['top5'] / stats['n']),
                stats['nll'] / stats['n'],
                time.time() - start),
                '%.3e MultiplyAdds/sample %.3e HiddenStates/sample' % (
                s.forward_pass_multiplyAdd_count / stats['n'],
                s.forward_pass_hidden_states / stats['n']))
        else:
            for rep in range(1,p['test_reps']+1):
                pr=[]
                ta=[]
                idxs=[]
                for batch in dataset['val']():
                    if p['use_gpu']:
                        batch['input']=batch['input'].cuda()
batch['target'] = batch['target'].cuda()
batch['idx'] = batch['idx'].cuda()
batch['input'].to_variable()
pr.append( model(batch['input']).data )
ta.append( batch['target'] )
idxs.append( batch['idx'] )
pr=torch.cat(pr,0)
ta=torch.cat(ta,0)
idxs=torch.cat(idxs,0)
if rep==1:
target=pr.index_select(0,idxs)
ta=ta.index_select(0,idxs)
else:
target.index_add_(0,idxs,pr)
loss = criterion(pr, ta)
stats = {'top1': 0, 'top5': 0, 'n': 0, 'nll': 0}
updateStats(stats, pr, ta, loss.data[0])
print(epoch, 'test rep ', rep,
': top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs' %(
100 * (1 - 1.0 * stats['top1'] / stats['n']),
100 * (1 - 1.0 * stats['top5'] / stats['n']),
stats['nll'] / stats['n'],
time.time() - start),
'%.3e MultiplyAdds/sample %.3e HiddenStates/sample' % (
s.forward_pass_multiplyAdd_count / stats['n'],
s.forward_pass_hidden_states / stats['n']))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment