Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Pytorch-Encoding
Commits
3e384b4a
Unverified
Commit
3e384b4a
authored
Apr 25, 2020
by
Hang Zhang
Committed by
GitHub
Apr 25, 2020
Browse files
DeepLabV3_ResNeSt200_ADE 48.36 mIoU (#264)
parent
69ba6789
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
298 additions
and
20 deletions
+298
-20
docs/source/model_zoo/segmentation.rst
docs/source/model_zoo/segmentation.rst
+24
-19
encoding/models/model_store.py
encoding/models/model_store.py
+1
-0
encoding/nn/dropblock.py
encoding/nn/dropblock.py
+0
-1
experiments/segmentation/train.py
experiments/segmentation/train.py
+273
-0
No files found.
docs/source/model_zoo/segmentation.rst
View file @
3e384b4a
...
@@ -36,31 +36,36 @@ ResNeSt Backbone Models
...
@@ -36,31 +36,36 @@ ResNeSt Backbone Models
Model
pixAcc
mIoU
Command
Model
pixAcc
mIoU
Command
==============================================================================
==============
==============
=========================================================================================================
==============================================================================
==============
==============
=========================================================================================================
FCN_ResNeSt50_ADE
80.18
%
42.94
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_fcn_nest50_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
FCN_ResNeSt50_ADE
80.18
%
42.94
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_fcn_nest50_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLabV3_ResNeSt50_ADE
81.17
%
45.12
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest50_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLab_ResNeSt50_ADE
81.17
%
45.12
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest50_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLabV3_ResNeSt101_ADE
82.07
%
46.91
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest101_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLab_ResNeSt101_ADE
82.07
%
46.91
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest101_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLabV3_ResNeSt269_ADE
82.62
%
47.60
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest269_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLab_ResNeSt200_ADE
82.45
%
48.36
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest200_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
DeepLab_ResNeSt269_ADE
82.62
%
47.60
%
:
raw
-
html
:`<
a
href
=
"javascript:toggleblock('cmd_deeplab_resnest269_ade')"
class
=
"toggleblock"
>
cmd
</
a
>`
==============================================================================
==============
==============
=========================================================================================================
==============================================================================
==============
==============
=========================================================================================================
..
raw
::
html
..
raw
::
html
<
code
xml
:
space
=
"preserve"
id
=
"cmd_fcn_nest50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_fcn_nest50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
python
train
_dist
.
py
--
dataset
ADE20K
--
model
fcn
--
aux
--
backbone
resnest50
python
train
.
py
--
dataset
ADE20K
--
model
fcn
--
aux
--
backbone
resnest50
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc_nest50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc_nest50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
python
train
_dist
.
py
--
dataset
ADE20K
--
model
EncNet
--
aux
--
se
-
loss
--
backbone
resnest50
python
train
.
py
--
dataset
ADE20K
--
model
EncNet
--
aux
--
se
-
loss
--
backbone
resnest50
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
python
train
_dist
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest50
python
train
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest50
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest101_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest101_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
python
train_dist
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest101
python
train
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest101
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest200_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
python
train
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest200
--
epochs
180
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest269_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_deeplab_resnest269_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
python
train
_dist
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest269
python
train
.
py
--
dataset
ADE20K
--
model
deeplab
--
aux
--
backbone
resnest269
</
code
>
</
code
>
...
@@ -82,19 +87,19 @@ EncNet_ResNet101s_ADE
...
@@ -82,19 +87,19 @@ EncNet_ResNet101s_ADE
..
raw
::
html
..
raw
::
html
<
code
xml
:
space
=
"preserve"
id
=
"cmd_fcn50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_fcn50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
ADE20K
--
model
FCN
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
ADE20K
--
model
FCN
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_psp50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_psp50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
ADE20K
--
model
PSP
--
aux
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
ADE20K
--
model
PSP
--
aux
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc50_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
ADE20K
--
model
EncNet
--
aux
--
se
-
loss
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
ADE20K
--
model
EncNet
--
aux
--
se
-
loss
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc101_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc101_ade"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
ADE20K
--
model
EncNet
--
aux
--
se
-
loss
--
backbone
resnet101
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
ADE20K
--
model
EncNet
--
aux
--
se
-
loss
--
backbone
resnet101
</
code
>
</
code
>
Pascal
Context
Dataset
Pascal
Context
Dataset
...
@@ -110,15 +115,15 @@ EncNet_ResNet101s_PContext
...
@@ -110,15 +115,15 @@ EncNet_ResNet101s_PContext
..
raw
::
html
..
raw
::
html
<
code
xml
:
space
=
"preserve"
id
=
"cmd_fcn50_pcont"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_fcn50_pcont"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
PContext
--
model
FCN
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
PContext
--
model
FCN
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc50_pcont"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc50_pcont"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
PContext
--
model
EncNet
--
aux
--
se
-
loss
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
PContext
--
model
EncNet
--
aux
--
se
-
loss
</
code
>
</
code
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc101_pcont"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc101_pcont"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
PContext
--
model
EncNet
--
aux
--
se
-
loss
--
backbone
resnet101
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
PContext
--
model
EncNet
--
aux
--
se
-
loss
--
backbone
resnet101
</
code
>
</
code
>
...
@@ -136,9 +141,9 @@ EncNet_ResNet101s_VOC
...
@@ -136,9 +141,9 @@ EncNet_ResNet101s_VOC
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc101_voc"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
<
code
xml
:
space
=
"preserve"
id
=
"cmd_enc101_voc"
style
=
"display: none; text-align: left; white-space: pre-wrap"
>
#
First
finetuning
COCO
dataset
pretrained
model
on
augmented
set
#
First
finetuning
COCO
dataset
pretrained
model
on
augmented
set
#
You
can
also
train
from
scratch
on
COCO
by
yourself
#
You
can
also
train
from
scratch
on
COCO
by
yourself
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
Pascal_aug
--
model
-
zoo
EncNet_Resnet101_COCO
--
aux
--
se
-
loss
--
lr
0.001
--
syncbn
--
ngpus
4
--
checkname
res101
--
ft
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
Pascal_aug
--
model
-
zoo
EncNet_Resnet101_COCO
--
aux
--
se
-
loss
--
lr
0.001
--
syncbn
--
ngpus
4
--
checkname
res101
--
ft
#
Finetuning
on
original
set
#
Finetuning
on
original
set
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
_dist
.
py
--
dataset
Pascal_voc
--
model
encnet
--
aux
--
se
-
loss
--
backbone
resnet101
--
lr
0.0001
--
syncbn
--
ngpus
4
--
checkname
res101
--
resume
runs
/
Pascal_aug
/
encnet
/
res101
/
checkpoint
.
params
--
ft
CUDA_VISIBLE_DEVICES
=
0
,
1
,
2
,
3
python
train
.
py
--
dataset
Pascal_voc
--
model
encnet
--
aux
--
se
-
loss
--
backbone
resnet101
--
lr
0.0001
--
syncbn
--
ngpus
4
--
checkname
res101
--
resume
runs
/
Pascal_aug
/
encnet
/
res101
/
checkpoint
.
params
--
ft
</
code
>
</
code
>
...
@@ -165,9 +170,9 @@ Train Your Own Model
...
@@ -165,9 +170,9 @@ Train Your Own Model
-
The
training
script
is
in
the
``
experiments
/
segmentation
/``
folder
,
example
training
command
::
-
The
training
script
is
in
the
``
experiments
/
segmentation
/``
folder
,
example
training
command
::
python
train
_dist
.
py
--
dataset
ade20k
--
model
encnet
--
aux
--
se
-
loss
python
train
.
py
--
dataset
ade20k
--
model
encnet
--
aux
--
se
-
loss
-
Detail
training
options
,
please
run
``
python
train
_dist
.
py
-
h
``.
Commands
for
reproducing
pre
-
trained
models
can
be
found
in
the
table
.
- For detailed training options, run ``python train.py -h``. Commands for reproducing pre-trained models can be found in the table.
..
hint
::
..
hint
::
The
validation
metrics
during
the
training
only
using
center
-
crop
is
just
for
monitoring
the
The
validation
metrics
during
the
training
only
using
center
-
crop
is
just
for
monitoring
the
...
...
encoding/models/model_store.py
View file @
3e384b4a
...
@@ -34,6 +34,7 @@ _model_sha1 = {name: checksum for checksum, name in [
...
@@ -34,6 +34,7 @@ _model_sha1 = {name: checksum for checksum, name in [
(
'4aba491aaf8e4866a9c9981b210e3e3266ac1f2a'
,
'fcn_resnest50_ade'
),
(
'4aba491aaf8e4866a9c9981b210e3e3266ac1f2a'
,
'fcn_resnest50_ade'
),
(
'2225f09d0f40b9a168d9091652194bc35ec2a5a9'
,
'deeplab_resnest50_ade'
),
(
'2225f09d0f40b9a168d9091652194bc35ec2a5a9'
,
'deeplab_resnest50_ade'
),
(
'06ca799c8cc148fe0fafb5b6d052052935aa3cc8'
,
'deeplab_resnest101_ade'
),
(
'06ca799c8cc148fe0fafb5b6d052052935aa3cc8'
,
'deeplab_resnest101_ade'
),
(
'7b9e7d3e6f0e2c763c7d77cad14d306c0a31fe05'
,
'deeplab_resnest200_ade'
),
(
'0074dd10a6e6696f6f521653fb98224e75955496'
,
'deeplab_resnest269_ade'
),
(
'0074dd10a6e6696f6f521653fb98224e75955496'
,
'deeplab_resnest269_ade'
),
]}
]}
...
...
encoding/nn/dropblock.py
View file @
3e384b4a
...
@@ -123,5 +123,4 @@ def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
...
@@ -123,5 +123,4 @@ def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
net.apply(apply_drop_prob)
net.apply(apply_drop_prob)
"""
"""
if
isinstance
(
m
,
DropBlock2D
):
if
isinstance
(
m
,
DropBlock2D
):
print
(
'reseting dropblock'
)
m
.
reset_steps
(
start_step
,
nr_steps
,
start_value
,
stop_value
)
m
.
reset_steps
(
start_step
,
nr_steps
,
start_value
,
stop_value
)
experiments/segmentation/train.py
0 → 100644
View file @
3e384b4a
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
import
os
import
copy
import
argparse
import
numpy
as
np
from
tqdm
import
tqdm
import
torch
from
torch.utils
import
data
import
torchvision.transforms
as
transform
from
torch.nn.parallel.scatter_gather
import
gather
import
encoding.utils
as
utils
from
encoding.nn
import
SegmentationLosses
,
SyncBatchNorm
from
encoding.parallel
import
DataParallelModel
,
DataParallelCriterion
from
encoding.datasets
import
get_dataset
from
encoding.models
import
get_segmentation_model
class Options:
    """Command-line options for the segmentation training script.

    The constructor only builds the ``argparse`` parser (stored on
    ``self.parser``); call :meth:`parse` to obtain the populated namespace
    with the derived ``cuda``, ``epochs`` and ``lr`` values filled in.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description='PyTorch Segmentation')
        # model and dataset
        parser.add_argument('--model', type=str, default='encnet',
                            help='model name (default: encnet)')
        parser.add_argument('--backbone', type=str, default='resnet50',
                            help='backbone name (default: resnet50)')
        parser.add_argument('--dataset', type=str, default='ade20k',
                            help='dataset name (default: ade20k)')
        parser.add_argument('--workers', type=int, default=16,
                            metavar='N', help='dataloader threads')
        parser.add_argument('--base-size', type=int, default=520,
                            help='base image size')
        parser.add_argument('--crop-size', type=int, default=480,
                            help='crop image size')
        parser.add_argument('--train-split', type=str, default='train',
                            help='dataset train split (default: train)')
        # training hyper params
        parser.add_argument('--aux', action='store_true', default=False,
                            help='Auxilary Loss')
        parser.add_argument('--aux-weight', type=float, default=0.2,
                            help='Auxilary loss weight (default: 0.2)')
        parser.add_argument('--se-loss', action='store_true', default=False,
                            help='Semantic Encoding Loss SE-loss')
        parser.add_argument('--se-weight', type=float, default=0.2,
                            help='SE-loss weight (default: 0.2)')
        parser.add_argument('--epochs', type=int, default=None, metavar='N',
                            help='number of epochs to train (default: auto)')
        parser.add_argument('--start_epoch', type=int, default=0,
                            metavar='N', help='start epochs (default:0)')
        parser.add_argument('--batch-size', type=int, default=16,
                            metavar='N',
                            help='input batch size for training (default: 16)')
        parser.add_argument('--test-batch-size', type=int, default=16,
                            metavar='N',
                            help='input batch size for testing (default: 16)')
        # optimizer params
        parser.add_argument('--lr', type=float, default=None, metavar='LR',
                            help='learning rate (default: auto)')
        parser.add_argument('--lr-scheduler', type=str, default='poly',
                            help='learning rate scheduler (default: poly)')
        parser.add_argument('--momentum', type=float, default=0.9,
                            metavar='M', help='momentum (default: 0.9)')
        parser.add_argument('--weight-decay', type=float, default=1e-4,
                            metavar='M', help='w-decay (default: 1e-4)')
        # cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        # checking point
        parser.add_argument('--resume', type=str, default=None,
                            help='put the path to resuming file if needed')
        parser.add_argument('--checkname', type=str, default='default',
                            help='set the checkpoint name')
        parser.add_argument('--model-zoo', type=str, default=None,
                            help='evaluating on model zoo model')
        # finetuning pre-trained models
        parser.add_argument('--ft', action='store_true', default=False,
                            help='finetuning on a different dataset')
        # evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluating mIoU')
        parser.add_argument('--test-val', action='store_true', default=False,
                            help='generate masks on val set')
        parser.add_argument('--no-val', action='store_true', default=False,
                            help='skip validation during training')
        # test option
        parser.add_argument('--test-folder', type=str, default=None,
                            help='path to test image folder')
        # the parser
        self.parser = parser

    def parse(self, argv=None):
        """Parse the command line and fill in dataset-dependent defaults.

        Args:
            argv: optional list of argument strings. ``None`` (the default)
                makes ``argparse`` read ``sys.argv[1:]``, which is the
                original behaviour.

        Returns:
            The ``argparse.Namespace`` with three derived fields:
            ``cuda`` (CUDA requested *and* available), ``epochs`` and ``lr``
            (looked up per dataset when not given explicitly).

        Raises:
            KeyError: if ``--epochs`` or ``--lr`` is omitted and the
                lower-cased dataset name has no entry in the default tables.
        """
        args = self.parser.parse_args(argv)
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        # Dataset names are matched case-insensitively against the tables.
        dataset_key = args.dataset.lower()
        # default settings for epochs and lr
        if args.epochs is None:
            epoches = {
                'coco': 30,
                'pascal_aug': 80,
                'pascal_voc': 50,
                'pcontext': 80,
                'ade20k': 180,
                'citys': 240,
            }
            args.epochs = epoches[dataset_key]
        if args.lr is None:
            lrs = {
                'coco': 0.004,
                'pascal_aug': 0.001,
                'pascal_voc': 0.0001,
                'pcontext': 0.001,
                'ade20k': 0.004,
                'citys': 0.004,
            }
            # The table values are for a batch of 16; scale linearly with
            # the actual batch size.
            args.lr = lrs[dataset_key] / 16 * args.batch_size
        print(args)
        return args
class Trainer:
    """Builds data loaders, model, optimizer and loss from ``args`` and runs
    one epoch of training or validation per method call.

    ``args`` is the namespace produced by ``Options.parse``; it is kept on
    ``self.args`` and ``args.start_epoch`` is updated in place when a
    checkpoint is resumed or fine-tuning is requested.
    """

    def __init__(self, args):
        """Set up everything needed for training.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])])
        # dataset
        data_kwargs = {'transform': input_transform,
                       'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_dataset(args.dataset, split=args.train_split,
                               mode='train', **data_kwargs)
        testset = get_dataset(args.dataset, split='val', mode='val',
                              **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True, shuffle=True,
                                           **kwargs)
        # NOTE(review): validation uses --batch-size; --test-batch-size is
        # parsed by Options but not consumed here. Left unchanged on purpose.
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False, shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        print(model)
        # optimizer using different LR: the backbone trains at args.lr,
        # the head / auxiliary layer (when present) at 10x that rate.
        params_list = [{'params': model.pretrained.parameters(),
                        'lr': args.lr}]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(),
                                'lr': args.lr * 10})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(),
                                'lr': args.lr * 10})
        optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # Initialise the best score BEFORE the resume logic so that a value
        # restored from a checkpoint is not overwritten afterwards.
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'"
                                   .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                # Only carry the best score over when continuing the same
                # run; with --ft the checkpoint comes from another dataset
                # and its score is not comparable.
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler_Head(args.lr_scheduler, args.lr,
                                                 args.epochs,
                                                 len(self.trainloader))

    def _unwrapped_model(self):
        """Return the bare network, stripping the data-parallel wrapper.

        The model is wrapped in ``DataParallelModel`` only when
        ``args.cuda`` is set, so ``.module`` exists only in that case.
        """
        return self.model.module if self.args.cuda else self.model

    def _save_checkpoint(self, epoch, is_best):
        """Write the checkpoint for the epoch that just finished."""
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self._unwrapped_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, self.args, is_best)

    def training(self, epoch):
        """Run one training epoch over ``self.trainloader``.

        When validation is disabled (``--no-val``) a checkpoint is saved at
        the end of the epoch, since ``validation`` will not do it.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            # The scheduler is invoked per iteration and adjusts the
            # optimizer's learning rate in place.
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # Progress bar shows the running mean loss of this epoch.
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, False)

    def validation(self, epoch):
        """Evaluate on ``self.valloader`` and save a checkpoint.

        The score used for model selection is the mean of pixel accuracy
        and mIoU; the checkpoint is flagged as best when it improves on
        ``self.best_pred``.
        """
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            if self.args.cuda:
                # Collect the per-device outputs of the parallel model on
                # device 0 and move the labels to the GPU as well.
                outputs = gather(outputs, 0, dim=0)
                target = target.cuda()
            # NOTE(review): the CPU path assumes the unwrapped model also
            # returns a sequence whose first item is the main prediction
            # -- confirm against the model implementations.
            pred = outputs[0]
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        # Defined up front so an empty validation loader does not leave
        # them unbound below.
        pixAcc, mIoU = 0.0, 0.0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            with torch.no_grad():
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) is a tiny epsilon guarding against division
            # by zero.
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description(
                'pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        self._save_checkpoint(epoch, is_best)
if __name__ == "__main__":
    # Script entry point: parse the CLI, seed torch's RNG, build the
    # trainer, then either evaluate once or run the epoch loop.
    args = Options().parse()
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    run_args = trainer.args
    print('Starting Epoch:', run_args.start_epoch)
    print('Total Epoches:', run_args.epochs)
    if not args.eval:
        for epoch in range(run_args.start_epoch, run_args.epochs):
            trainer.training(epoch)
            if run_args.no_val:
                # Validation disabled: training() already saved a checkpoint.
                continue
            trainer.validation(epoch)
    else:
        # Evaluation-only mode: a single validation pass.
        trainer.validation(run_args.start_epoch)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment