Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
6722cac3
Commit
6722cac3
authored
Oct 26, 2017
by
Benjamin Graham
Browse files
Enable multi testing in ClassificationTrainValidate
parent
d8b64558
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
28 deletions
+56
-28
PyTorch/sparseconvnet/classificationTrainValidate.py
PyTorch/sparseconvnet/classificationTrainValidate.py
+56
-28
No files found.
PyTorch/sparseconvnet/classificationTrainValidate.py
View file @
6722cac3
...
...
@@ -30,7 +30,7 @@ def updateStats(stats, output, target, loss):
def
ClassificationTrainValidate
(
model
,
dataset
,
p
):
criterion
=
nn
.
C
ross
E
ntropy
Loss
()
criterion
=
F
.
c
ross
_e
ntropy
if
'n_epochs'
not
in
p
:
p
[
'n_epochs'
]
=
100
if
'initial_lr'
not
in
p
:
...
...
@@ -47,7 +47,8 @@ def ClassificationTrainValidate(model, dataset, p):
p
[
'use_gpu'
]
=
torch
.
cuda
.
is_available
()
if
p
[
'use_gpu'
]:
model
.
cuda
()
criterion
.
cuda
()
if
'test_reps'
not
in
p
:
p
[
'test_reps'
]
=
1
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
p
[
'initial_lr'
],
momentum
=
p
[
'momentum'
],
...
...
@@ -100,30 +101,57 @@ def ClassificationTrainValidate(model, dataset, p):
model
.
eval
()
s
.
forward_pass_multiplyAdd_count
=
0
s
.
forward_pass_hidden_states
=
0
stats
=
{
'top1'
:
0
,
'top5'
:
0
,
'n'
:
0
,
'nll'
:
0
}
start
=
time
.
time
()
for
batch
in
dataset
[
'val'
]():
if
p
[
'use_gpu'
]:
batch
[
'input'
]
=
batch
[
'input'
].
cuda
()
batch
[
'target'
]
=
batch
[
'target'
].
cuda
()
batch
[
'input'
].
to_variable
()
batch
[
'target'
]
=
Variable
(
batch
[
'target'
])
output
=
model
(
batch
[
'input'
])
loss
=
criterion
(
output
,
batch
[
'target'
])
updateStats
(
stats
,
output
.
data
,
batch
[
'target'
].
data
,
loss
.
data
[
0
])
print
(
epoch
,
'test: top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs'
%
(
100
*
(
1
-
1.0
*
stats
[
'top1'
]
/
stats
[
'n'
]),
100
*
(
1
-
1.0
*
stats
[
'top5'
]
/
stats
[
'n'
]),
stats
[
'nll'
]
/
stats
[
'n'
],
time
.
time
()
-
start
))
print
(
'%.3e MultiplyAdds/sample %.3e HiddenStates/sample'
%
(
s
.
forward_pass_multiplyAdd_count
/
stats
[
'n'
],
s
.
forward_pass_hidden_states
/
stats
[
'n'
]))
if
p
[
'test_reps'
]
==
1
:
stats
=
{
'top1'
:
0
,
'top5'
:
0
,
'n'
:
0
,
'nll'
:
0
}
for
batch
in
dataset
[
'val'
]():
if
p
[
'use_gpu'
]:
batch
[
'input'
]
=
batch
[
'input'
].
cuda
()
batch
[
'target'
]
=
batch
[
'target'
].
cuda
()
batch
[
'input'
].
to_variable
()
batch
[
'target'
]
=
Variable
(
batch
[
'target'
])
output
=
model
(
batch
[
'input'
])
loss
=
criterion
(
output
,
batch
[
'target'
])
updateStats
(
stats
,
output
.
data
,
batch
[
'target'
].
data
,
loss
.
data
[
0
])
print
(
epoch
,
'test: top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs'
%
(
100
*
(
1
-
1.0
*
stats
[
'top1'
]
/
stats
[
'n'
]),
100
*
(
1
-
1.0
*
stats
[
'top5'
]
/
stats
[
'n'
]),
stats
[
'nll'
]
/
stats
[
'n'
],
time
.
time
()
-
start
),
'%.3e MultiplyAdds/sample %.3e HiddenStates/sample'
%
(
s
.
forward_pass_multiplyAdd_count
/
stats
[
'n'
],
s
.
forward_pass_hidden_states
/
stats
[
'n'
]))
else
:
for
rep
in
range
(
1
,
p
[
'test_reps'
]
+
1
):
pr
=
[]
ta
=
[]
idxs
=
[]
for
batch
in
dataset
[
'val'
]():
if
p
[
'use_gpu'
]:
batch
[
'input'
]
=
batch
[
'input'
].
cuda
()
batch
[
'target'
]
=
batch
[
'target'
].
cuda
()
batch
[
'idx'
]
=
batch
[
'idx'
].
cuda
()
batch
[
'input'
].
to_variable
()
pr
.
append
(
model
(
batch
[
'input'
]).
data
)
ta
.
append
(
batch
[
'target'
]
)
idxs
.
append
(
batch
[
'idx'
]
)
pr
=
torch
.
cat
(
pr
,
0
)
ta
=
torch
.
cat
(
ta
,
0
)
idxs
=
torch
.
cat
(
idxs
,
0
)
if
rep
==
1
:
target
=
pr
.
index_select
(
0
,
idxs
)
ta
=
ta
.
index_select
(
0
,
idxs
)
else
:
target
.
index_add_
(
0
,
idxs
,
pr
)
loss
=
criterion
(
pr
,
ta
)
stats
=
{
'top1'
:
0
,
'top5'
:
0
,
'n'
:
0
,
'nll'
:
0
}
updateStats
(
stats
,
pr
,
ta
,
loss
.
data
[
0
])
print
(
epoch
,
'test rep '
,
rep
,
': top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs'
%
(
100
*
(
1
-
1.0
*
stats
[
'top1'
]
/
stats
[
'n'
]),
100
*
(
1
-
1.0
*
stats
[
'top5'
]
/
stats
[
'n'
]),
stats
[
'nll'
]
/
stats
[
'n'
],
time
.
time
()
-
start
),
'%.3e MultiplyAdds/sample %.3e HiddenStates/sample'
%
(
s
.
forward_pass_multiplyAdd_count
/
stats
[
'n'
],
s
.
forward_pass_hidden_states
/
stats
[
'n'
]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment