Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
8210ee32
Commit
8210ee32
authored
Oct 27, 2017
by
Benjamin Thomas Graham
Browse files
confusion matrix
parent
254109fd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
6 deletions
+22
-6
PyTorch/sparseconvnet/classificationTrainValidate.py
PyTorch/sparseconvnet/classificationTrainValidate.py
+22
-6
No files found.
PyTorch/sparseconvnet/classificationTrainValidate.py
View file @
8210ee32
...
...
@@ -13,20 +13,29 @@ import sparseconvnet as s
import
time
import
os
import
math
import
numpy
as
np
import
PIL
def
updateStats
(
stats
,
output
,
target
,
loss
):
batchSize
=
output
.
size
(
0
)
nClasses
=
output
.
size
(
1
)
if
not
stats
:
stats
[
'top1'
]
=
0
stats
[
'top5'
]
=
0
stats
[
'n'
]
=
0
stats
[
'nll'
]
=
0
stats
[
'confusion matrix'
]
=
output
.
new
().
resize_
(
nClasses
,
nClasses
).
zero_
()
stats
[
'n'
]
=
stats
[
'n'
]
+
batchSize
stats
[
'nll'
]
=
stats
[
'nll'
]
+
loss
*
batchSize
_
,
predictions
=
output
.
float
().
sort
(
1
,
True
)
correct
=
predictions
.
eq
(
target
.
long
()
[:,
None
].
expand_as
(
output
))
target
[:,
None
].
expand_as
(
output
))
# Top-1 score
stats
[
'top1'
]
+=
correct
.
narrow
(
1
,
0
,
1
).
sum
()
# Top-5 score
l
=
min
(
5
,
correct
.
size
(
1
))
stats
[
'top5'
]
+=
correct
.
narrow
(
1
,
0
,
l
).
sum
()
stats
[
'confusion matrix'
].
index_add_
(
0
,
target
,
F
.
softmax
(
output
).
data
)
def
ClassificationTrainValidate
(
model
,
dataset
,
p
):
...
...
@@ -66,7 +75,7 @@ def ClassificationTrainValidate(model, dataset, p):
print
(
'#parameters'
,
sum
([
x
.
nelement
()
for
x
in
model
.
parameters
()]))
for
epoch
in
range
(
p
[
'epoch'
],
p
[
'n_epochs'
]
+
1
):
model
.
train
()
stats
=
{
'top1'
:
0
,
'top5'
:
0
,
'n'
:
0
,
'nll'
:
0
}
stats
=
{}
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
p
[
'initial_lr'
]
*
\
math
.
exp
((
1
-
epoch
)
*
p
[
'lr_decay'
])
...
...
@@ -93,7 +102,10 @@ def ClassificationTrainValidate(model, dataset, p):
stats
[
'n'
]),
stats
[
'nll'
]
/
stats
[
'n'
],
time
.
time
()
-
start
))
cm
=
stats
[
'confusion matrix'
].
cpu
().
numpy
()
np
.
savetxt
(
'train confusion matrix.csv'
,
cm
,
delimiter
=
','
)
cm
*=
255
/
(
cm
.
sum
(
1
,
keepdims
=
True
)
+
1e-9
)
PIL
.
Image
.
fromarray
(
cm
.
astype
(
'uint8'
),
mode
=
'L'
).
save
(
'train confusion matrix.png'
)
if
p
[
'check_point'
]:
torch
.
save
(
epoch
,
'epoch.pth'
)
torch
.
save
(
model
.
state_dict
(),
'model.pth'
)
...
...
@@ -103,7 +115,7 @@ def ClassificationTrainValidate(model, dataset, p):
s
.
forward_pass_hidden_states
=
0
start
=
time
.
time
()
if
p
[
'test_reps'
]
==
1
:
stats
=
{
'top1'
:
0
,
'top5'
:
0
,
'n'
:
0
,
'nll'
:
0
}
stats
=
{}
for
batch
in
dataset
[
'val'
]():
if
p
[
'use_gpu'
]:
batch
[
'input'
]
=
batch
[
'input'
].
cuda
()
...
...
@@ -145,7 +157,7 @@ def ClassificationTrainValidate(model, dataset, p):
else
:
predictions
.
index_add_
(
0
,
idxs
,
pr
)
loss
=
criterion
(
predictions
/
rep
,
targets
)
stats
=
{
'top1'
:
0
,
'top5'
:
0
,
'n'
:
0
,
'nll'
:
0
}
stats
=
{}
updateStats
(
stats
,
predictions
,
targets
,
loss
.
data
[
0
])
print
(
epoch
,
'test rep '
,
rep
,
': top1=%.2f%% top5=%.2f%% nll:%.2f time:%.1fs'
%
(
...
...
@@ -156,3 +168,7 @@ def ClassificationTrainValidate(model, dataset, p):
'%.3e MultiplyAdds/sample %.3e HiddenStates/sample'
%
(
s
.
forward_pass_multiplyAdd_count
/
stats
[
'n'
],
s
.
forward_pass_hidden_states
/
stats
[
'n'
]))
cm
=
stats
[
'confusion matrix'
].
cpu
().
numpy
()
np
.
savetxt
(
'test confusion matrix.csv'
,
cm
,
delimiter
=
','
)
cm
*=
255
/
(
cm
.
sum
(
1
,
keepdims
=
True
)
+
1e-9
)
PIL
.
Image
.
fromarray
(
cm
.
astype
(
'uint8'
),
mode
=
'L'
).
save
(
'test confusion matrix.png'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment