Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
RODNet
Commits
cc8336e3
"vscode:/vscode.git/clone" did not exist on "c6e6c8ee7e86b774436a18e72bb3b5f1c495ae43"
Commit
cc8336e3
authored
Feb 05, 2022
by
Yizhou Wang
Browse files
add average loss to log and tensorboard
parent
1d635cff
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
4 deletions
+11
-4
tools/train.py
tools/train.py
+11
-4
No files found.
tools/train.py
View file @
cc8336e3
...
...
@@ -2,6 +2,7 @@ import os
import
time
import
json
import
argparse
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
@@ -173,6 +174,8 @@ if __name__ == "__main__":
scheduler
=
StepLR
(
optimizer
,
step_size
=
config_dict
[
'train_cfg'
][
'lr_step'
],
gamma
=
0.1
)
iter_count
=
0
loss_ave
=
0
if
cp_path
is
not
None
:
checkpoint
=
torch
.
load
(
cp_path
)
if
'optimizer_state_dict'
in
checkpoint
:
...
...
@@ -229,13 +232,15 @@ if __name__ == "__main__":
loss_confmap
.
backward
()
optimizer
.
step
()
loss_ave
=
np
.
average
([
loss_ave
,
loss_confmap
.
item
()],
weights
=
[
iter_count
,
1
])
if
iter
%
config_dict
[
'train_cfg'
][
'log_step'
]
==
0
:
# print statistics
print
(
'epoch %2d, iter %4d: loss: %.
8f
| load time: %.
4
f | backward time: %.
4
f'
%
(
epoch
+
1
,
iter
+
1
,
loss_confmap
.
item
(),
tic
-
tic_load
,
time
.
time
()
-
tic
))
print
(
'epoch %2d, iter %4d: loss: %.
6f (%.4f)
| load time: %.
2
f | backward time: %.
2
f'
%
(
epoch
+
1
,
iter
+
1
,
loss_confmap
.
item
(),
loss_ave
,
tic
-
tic_load
,
time
.
time
()
-
tic
))
with
open
(
train_log_name
,
'a+'
)
as
f_log
:
f_log
.
write
(
'epoch %2d, iter %4d: loss: %.
8f
| load time: %.
4
f | backward time: %.
4
f
\n
'
%
(
epoch
+
1
,
iter
+
1
,
loss_confmap
.
item
(),
tic
-
tic_load
,
time
.
time
()
-
tic
))
f_log
.
write
(
'epoch %2d, iter %4d: loss: %.
6f (%.4f)
| load time: %.
2
f | backward time: %.
2
f
\n
'
%
(
epoch
+
1
,
iter
+
1
,
loss_confmap
.
item
(),
loss_ave
,
tic
-
tic_load
,
time
.
time
()
-
tic
))
if
stacked_num
is
not
None
:
writer
.
add_scalar
(
'loss/loss_all'
,
loss_confmap
.
item
(),
iter_count
)
...
...
@@ -243,6 +248,8 @@ if __name__ == "__main__":
else
:
writer
.
add_scalar
(
'loss/loss_all'
,
loss_confmap
.
item
(),
iter_count
)
confmap_pred
=
confmap_preds
.
cpu
().
detach
().
numpy
()
writer
.
add_scalar
(
'loss/loss_ave'
,
loss_ave
,
iter_count
)
if
'mnet_cfg'
in
model_cfg
:
chirp_amp_curr
=
chirp_amp
(
data
.
numpy
()[
0
,
:,
0
,
0
,
:,
:],
radar_configs
[
'data_type'
])
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment