Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9329257b
Unverified
Commit
9329257b
authored
Apr 29, 2020
by
D. Khuê Lê-Huu
Committed by
GitHub
Apr 29, 2020
Browse files
Fix training resuming in references/segmentation (#2142)
parent
1affa2e8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
5 deletions
+11
-5
references/segmentation/train.py
references/segmentation/train.py
+11
-5
No files found.
references/segmentation/train.py
View file @
9329257b
...
@@ -128,10 +128,6 @@ def main(args):
...
@@ -128,10 +128,6 @@ def main(args):
if
args
.
distributed
:
if
args
.
distributed
:
model
=
torch
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
model
=
torch
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
if
args
.
resume
:
checkpoint
=
torch
.
load
(
args
.
resume
,
map_location
=
'cpu'
)
model
.
load_state_dict
(
checkpoint
[
'model'
])
model_without_ddp
=
model
model_without_ddp
=
model
if
args
.
distributed
:
if
args
.
distributed
:
model
=
torch
.
nn
.
parallel
.
DistributedDataParallel
(
model
,
device_ids
=
[
args
.
gpu
])
model
=
torch
.
nn
.
parallel
.
DistributedDataParallel
(
model
,
device_ids
=
[
args
.
gpu
])
...
@@ -157,8 +153,15 @@ def main(args):
...
@@ -157,8 +153,15 @@ def main(args):
optimizer
,
optimizer
,
lambda
x
:
(
1
-
x
/
(
len
(
data_loader
)
*
args
.
epochs
))
**
0.9
)
lambda
x
:
(
1
-
x
/
(
len
(
data_loader
)
*
args
.
epochs
))
**
0.9
)
if
args
.
resume
:
checkpoint
=
torch
.
load
(
args
.
resume
,
map_location
=
'cpu'
)
model_without_ddp
.
load_state_dict
(
checkpoint
[
'model'
])
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
lr_scheduler
.
load_state_dict
(
checkpoint
[
'lr_scheduler'
])
args
.
start_epoch
=
checkpoint
[
'epoch'
]
+
1
start_time
=
time
.
time
()
start_time
=
time
.
time
()
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
start_epoch
,
args
.
epochs
):
if
args
.
distributed
:
if
args
.
distributed
:
train_sampler
.
set_epoch
(
epoch
)
train_sampler
.
set_epoch
(
epoch
)
train_one_epoch
(
model
,
criterion
,
optimizer
,
data_loader
,
lr_scheduler
,
device
,
epoch
,
args
.
print_freq
)
train_one_epoch
(
model
,
criterion
,
optimizer
,
data_loader
,
lr_scheduler
,
device
,
epoch
,
args
.
print_freq
)
...
@@ -168,6 +171,7 @@ def main(args):
...
@@ -168,6 +171,7 @@ def main(args):
{
{
'model'
:
model_without_ddp
.
state_dict
(),
'model'
:
model_without_ddp
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'lr_scheduler'
:
lr_scheduler
.
state_dict
(),
'epoch'
:
epoch
,
'epoch'
:
epoch
,
'args'
:
args
'args'
:
args
},
},
...
@@ -201,6 +205,8 @@ def parse_args():
...
@@ -201,6 +205,8 @@ def parse_args():
parser
.
add_argument
(
'--print-freq'
,
default
=
10
,
type
=
int
,
help
=
'print frequency'
)
parser
.
add_argument
(
'--print-freq'
,
default
=
10
,
type
=
int
,
help
=
'print frequency'
)
parser
.
add_argument
(
'--output-dir'
,
default
=
'.'
,
help
=
'path where to save'
)
parser
.
add_argument
(
'--output-dir'
,
default
=
'.'
,
help
=
'path where to save'
)
parser
.
add_argument
(
'--resume'
,
default
=
''
,
help
=
'resume from checkpoint'
)
parser
.
add_argument
(
'--resume'
,
default
=
''
,
help
=
'resume from checkpoint'
)
parser
.
add_argument
(
'--start-epoch'
,
default
=
0
,
type
=
int
,
metavar
=
'N'
,
help
=
'start epoch'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--test-only"
,
"--test-only"
,
dest
=
"test_only"
,
dest
=
"test_only"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment