Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
a0d33d69
Commit
a0d33d69
authored
Oct 27, 2017
by
Benjamin Thomas Graham
Browse files
confusion matrix
parent
8210ee32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
PyTorch/sparseconvnet/classificationTrainValidate.py
PyTorch/sparseconvnet/classificationTrainValidate.py
+3
-3
No files found.
PyTorch/sparseconvnet/classificationTrainValidate.py
View file @
a0d33d69
...
@@ -14,7 +14,7 @@ import time
...
@@ -14,7 +14,7 @@ import time
import
os
import
os
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
PIL
from
PIL
import
Image
def
updateStats
(
stats
,
output
,
target
,
loss
):
def
updateStats
(
stats
,
output
,
target
,
loss
):
batchSize
=
output
.
size
(
0
)
batchSize
=
output
.
size
(
0
)
...
@@ -105,7 +105,7 @@ def ClassificationTrainValidate(model, dataset, p):
...
@@ -105,7 +105,7 @@ def ClassificationTrainValidate(model, dataset, p):
cm
=
stats
[
'confusion matrix'
].
cpu
().
numpy
()
cm
=
stats
[
'confusion matrix'
].
cpu
().
numpy
()
np
.
savetxt
(
'train confusion matrix.csv'
,
cm
,
delimiter
=
','
)
np
.
savetxt
(
'train confusion matrix.csv'
,
cm
,
delimiter
=
','
)
cm
*=
255
/
(
cm
.
sum
(
1
,
keepdims
=
True
)
+
1e-9
)
cm
*=
255
/
(
cm
.
sum
(
1
,
keepdims
=
True
)
+
1e-9
)
PIL
.
Image
.
fromarray
(
cm
.
astype
(
'uint8'
),
mode
=
'L'
).
save
(
'train confusion matrix.png'
)
Image
.
fromarray
(
cm
.
astype
(
'uint8'
),
mode
=
'L'
).
save
(
'train confusion matrix.png'
)
if
p
[
'check_point'
]:
if
p
[
'check_point'
]:
torch
.
save
(
epoch
,
'epoch.pth'
)
torch
.
save
(
epoch
,
'epoch.pth'
)
torch
.
save
(
model
.
state_dict
(),
'model.pth'
)
torch
.
save
(
model
.
state_dict
(),
'model.pth'
)
...
@@ -171,4 +171,4 @@ def ClassificationTrainValidate(model, dataset, p):
...
@@ -171,4 +171,4 @@ def ClassificationTrainValidate(model, dataset, p):
cm
=
stats
[
'confusion matrix'
].
cpu
().
numpy
()
cm
=
stats
[
'confusion matrix'
].
cpu
().
numpy
()
np
.
savetxt
(
'test confusion matrix.csv'
,
cm
,
delimiter
=
','
)
np
.
savetxt
(
'test confusion matrix.csv'
,
cm
,
delimiter
=
','
)
cm
*=
255
/
(
cm
.
sum
(
1
,
keepdims
=
True
)
+
1e-9
)
cm
*=
255
/
(
cm
.
sum
(
1
,
keepdims
=
True
)
+
1e-9
)
PIL
.
Image
.
fromarray
(
cm
.
astype
(
'uint8'
),
mode
=
'L'
).
save
(
'test confusion matrix.png'
)
Image
.
fromarray
(
cm
.
astype
(
'uint8'
),
mode
=
'L'
).
save
(
'test confusion matrix.png'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment