Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
8500c14e
Commit
8500c14e
authored
Dec 18, 2018
by
ThangVu
Browse files
add group norm on backbone and unit tests
parent
033e537e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
595 additions
and
57 deletions
+595
-57
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+184
-44
mmdet/models/bbox_heads/convfc_bbox_head.py
mmdet/models/bbox_heads/convfc_bbox_head.py
+4
-11
mmdet/models/utils/norm.py
mmdet/models/utils/norm.py
+4
-2
tools/train_imagenet/train_imagenet.py
tools/train_imagenet/train_imagenet.py
+403
-0
No files found.
mmdet/models/backbones/resnet.py
View file @
8500c14e
import
logging
import
pickle
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
constant_init
,
kaiming_init
from
mmcv.runner
import
load_checkpoint
from
..utils
import
build_norm_layer
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
dilation
=
1
):
...
...
@@ -29,13 +32,21 @@ class BasicBlock(nn.Module):
dilation
=
1
,
downsample
=
None
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
,
normalize
=
dict
(
type
=
'GN'
)):
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
conv3x3
(
inplanes
,
planes
,
stride
,
dilation
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
norm_layers
=
[]
norm_layers
.
append
(
build_norm_layer
(
normalize
,
planes
))
norm_layers
.
append
(
build_norm_layer
(
normalize
,
planes
))
self
.
norm_names
=
([
'gn1'
,
'gn2'
]
if
normalize
[
'type'
]
==
'GN'
else
[
'bn1'
,
'bn2'
])
for
name
,
layer
in
zip
(
self
.
norm_names
,
norm_layers
):
self
.
add_module
(
name
,
layer
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
conv2
=
conv3x3
(
planes
,
planes
)
self
.
bn2
=
nn
.
BatchNorm2d
(
planes
)
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
dilation
=
dilation
...
...
@@ -45,11 +56,11 @@ class BasicBlock(nn.Module):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
getattr
(
self
,
self
.
norm_names
[
0
])
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
getattr
(
self
,
self
.
norm_names
[
1
])
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
...
...
@@ -70,7 +81,8 @@ class Bottleneck(nn.Module):
dilation
=
1
,
downsample
=
None
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
,
normalize
=
dict
(
type
=
'BN'
)):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
...
...
@@ -94,16 +106,23 @@ class Bottleneck(nn.Module):
dilation
=
dilation
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
planes
)
self
.
bn2
=
nn
.
BatchNorm2d
(
planes
)
norm_layers
=
[]
norm_layers
.
append
(
build_norm_layer
(
normalize
,
planes
))
norm_layers
.
append
(
build_norm_layer
(
normalize
,
planes
))
norm_layers
.
append
(
build_norm_layer
(
normalize
,
planes
*
self
.
expansion
))
self
.
norm_names
=
([
'gn1'
,
'gn2'
,
'gn3'
]
if
normalize
[
'type'
]
==
'GN'
else
[
'bn1'
,
'bn2'
,
'bn3'
])
for
name
,
layer
in
zip
(
self
.
norm_names
,
norm_layers
):
self
.
add_module
(
name
,
layer
)
self
.
conv3
=
nn
.
Conv2d
(
planes
,
planes
*
self
.
expansion
,
kernel_size
=
1
,
bias
=
False
)
self
.
bn3
=
nn
.
BatchNorm2d
(
planes
*
self
.
expansion
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
dilation
=
dilation
self
.
with_cp
=
with_cp
self
.
normalize
=
normalize
def
forward
(
self
,
x
):
...
...
@@ -111,15 +130,15 @@ class Bottleneck(nn.Module):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
getattr
(
self
,
self
.
norm_names
[
0
])
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
getattr
(
self
,
self
.
norm_names
[
1
])
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
out
=
getattr
(
self
,
self
.
norm_names
[
2
])
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
...
...
@@ -145,7 +164,8 @@ def make_res_layer(block,
stride
=
1
,
dilation
=
1
,
style
=
'pytorch'
,
with_cp
=
False
):
with_cp
=
False
,
normalize
=
dict
(
type
=
'BN'
)):
downsample
=
None
if
stride
!=
1
or
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
...
...
@@ -155,7 +175,7 @@ def make_res_layer(block,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
),
nn
.
BatchNorm2d
(
planes
*
block
.
expansion
),
build_norm_layer
(
normalize
,
planes
*
block
.
expansion
),
)
layers
=
[]
...
...
@@ -167,11 +187,13 @@ def make_res_layer(block,
dilation
,
downsample
,
style
=
style
,
with_cp
=
with_cp
))
with_cp
=
with_cp
,
normalize
=
normalize
))
inplanes
=
planes
*
block
.
expansion
for
i
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
inplanes
,
planes
,
1
,
dilation
,
style
=
style
,
with_cp
=
with_cp
))
block
(
inplanes
,
planes
,
1
,
dilation
,
style
=
style
,
with_cp
=
with_cp
,
normalize
=
normalize
))
return
nn
.
Sequential
(
*
layers
)
...
...
@@ -212,9 +234,11 @@ class ResNet(nn.Module):
dilations
=
(
1
,
1
,
1
,
1
),
out_indices
=
(
0
,
1
,
2
,
3
),
style
=
'pytorch'
,
frozen_stages
=-
1
,
bn_eval
=
True
,
bn_frozen
=
False
,
normalize
=
dict
(
type
=
'BN'
,
frozen_stages
=-
1
,
bn_eval
=
True
,
bn_frozen
=
False
),
with_cp
=
False
):
super
(
ResNet
,
self
).
__init__
()
if
depth
not
in
self
.
arch_settings
:
...
...
@@ -225,17 +249,29 @@ class ResNet(nn.Module):
assert
len
(
strides
)
==
len
(
dilations
)
==
num_stages
assert
max
(
out_indices
)
<
num_stages
assert
isinstance
(
normalize
,
dict
)
and
'type'
in
normalize
assert
normalize
[
'type'
]
in
[
'BN'
,
'GN'
]
if
normalize
[
'type'
]
==
'GN'
:
assert
'num_groups'
in
normalize
else
:
assert
(
set
([
'type'
,
'frozen_stages'
,
'bn_eval'
,
'bn_frozen'
])
==
set
(
normalize
))
self
.
out_indices
=
out_indices
self
.
style
=
style
self
.
frozen_stages
=
frozen_stages
self
.
bn_eval
=
bn_eval
self
.
bn_frozen
=
bn_frozen
self
.
with_cp
=
with_cp
if
normalize
[
'type'
]
==
'BN'
:
self
.
frozen_stages
=
normalize
[
'frozen_stages'
]
self
.
bn_eval
=
normalize
[
'bn_eval'
]
self
.
bn_frozen
=
normalize
[
'bn_frozen'
]
self
.
normalize
=
normalize
self
.
inplanes
=
64
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
nn
.
BatchNorm2d
(
64
)
stem_norm
=
build_norm_layer
(
normalize
,
64
)
self
.
stem_norm_name
=
'gn1'
if
normalize
[
'type'
]
==
'GN'
else
'bn1'
self
.
add_module
(
self
.
stem_norm_name
,
stem_norm
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
...
...
@@ -252,7 +288,8 @@ class ResNet(nn.Module):
stride
=
stride
,
dilation
=
dilation
,
style
=
self
.
style
,
with_cp
=
with_cp
)
with_cp
=
with_cp
,
normalize
=
normalize
)
self
.
inplanes
=
planes
*
block
.
expansion
layer_name
=
'layer{}'
.
format
(
i
+
1
)
self
.
add_module
(
layer_name
,
res_layer
)
...
...
@@ -270,12 +307,18 @@ class ResNet(nn.Module):
kaiming_init
(
m
)
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
constant_init
(
m
,
1
)
# zero init for last norm layer https://arxiv.org/abs/1706.02677
for
m
in
self
.
modules
():
if
isinstance
(
m
,
Bottleneck
)
or
isinstance
(
m
,
BasicBlock
):
last_norm
=
getattr
(
m
,
m
.
norm_names
[
-
1
])
constant_init
(
last_norm
,
0
)
else
:
raise
TypeError
(
'pretrained must be a str or None'
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
getattr
(
self
,
self
.
stem_norm_name
)
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
outs
=
[]
...
...
@@ -291,23 +334,120 @@ class ResNet(nn.Module):
def
train
(
self
,
mode
=
True
):
super
(
ResNet
,
self
).
train
(
mode
)
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
eval
()
if
self
.
bn_frozen
:
for
params
in
m
.
parameters
():
params
.
requires_grad
=
False
if
mode
and
self
.
frozen_stages
>=
0
:
for
param
in
self
.
conv1
.
parameters
():
param
.
requires_grad
=
False
for
param
in
self
.
bn1
.
parameters
():
param
.
requires_grad
=
False
self
.
bn1
.
eval
()
self
.
bn1
.
weight
.
requires_grad
=
False
self
.
bn1
.
bias
.
requires_grad
=
False
for
i
in
range
(
1
,
self
.
frozen_stages
+
1
):
mod
=
getattr
(
self
,
'layer{}'
.
format
(
i
))
mod
.
eval
()
for
param
in
mod
.
parameters
():
if
self
.
normalize
[
'type'
]
==
'BN'
:
if
self
.
bn_eval
:
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
eval
()
if
self
.
bn_frozen
:
for
params
in
m
.
parameters
():
params
.
requires_grad
=
False
if
mode
and
self
.
frozen_stages
>=
0
:
for
param
in
self
.
conv1
.
parameters
():
param
.
requires_grad
=
False
for
param
in
self
.
bn1
.
parameters
():
param
.
requires_grad
=
False
self
.
bn1
.
eval
()
self
.
bn1
.
weight
.
requires_grad
=
False
self
.
bn1
.
bias
.
requires_grad
=
False
for
i
in
range
(
1
,
self
.
frozen_stages
+
1
):
mod
=
getattr
(
self
,
'layer{}'
.
format
(
i
))
mod
.
eval
()
for
param
in
mod
.
parameters
():
param
.
requires_grad
=
False
class ResNetClassifier(ResNet):
    """ResNet backbone wrapped as an ImageNet-style classifier.

    Adds a global average pool and a final fully-connected layer on top of
    the ResNet stages, and can import Caffe2/Detectron pickle weights.
    """

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 normalize=dict(type='BN',
                                frozen_stages=-1,
                                bn_eval=True,
                                bn_frozen=False),
                 num_classes=1000,
                 with_cp=False):
        # NOTE(review): mutable dict default argument is shared across calls;
        # harmless only as long as callers never mutate `normalize` in place.
        super(ResNetClassifier, self).__init__(
            depth,
            num_stages=num_stages,
            strides=strides,
            dilations=dilations,
            out_indices=out_indices,
            style=style,
            normalize=normalize,
            with_cp=with_cp)
        # arch_settings[depth] -> (block class, per-stage block counts)
        _, self.stage_blocks = self.arch_settings[depth]
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # BasicBlock (depth 18) has expansion 1; Bottleneck variants use 4.
        # NOTE(review): depth 34 also uses BasicBlock — this picks 4 for it;
        # confirm depth 34 is not a supported input here.
        expansion = 1 if depth == 18 else 4
        self.fc = nn.Linear(512 * expansion, num_classes)
        self.init_weights()

    # TODO can be removed after tested
    def load_caffe2_weight(self, cf_path):
        """Load a Detectron/Caffe2 pickle checkpoint into this model.

        Builds a name mapping from PyTorch parameter names to Caffe2 blob
        names, then copies every mapped blob into the state dict.
        Asserts that every mapped key exists on both sides.
        """
        norm = 'gn' if self.normalize['type'] == 'GN' else 'bn'
        mapping = {}
        # Residual stages: PyTorch 'layer{1..4}' <-> Caffe2 'res{2..5}'.
        for layer, blocks_in_layer in enumerate(self.stage_blocks, 1):
            for blk in range(blocks_in_layer):
                cf_prefix = 'res%d_%d_' % (layer + 1, blk)
                py_prefix = 'layer%d.%d.' % (layer, blk)
                # conv branch
                for i, a in zip([1, 2, 3], ['a', 'b', 'c']):
                    cf_full = cf_prefix + 'branch2%s_' % a
                    mapping[py_prefix + 'conv%d.weight' % i] = cf_full + 'w'
                    # Caffe2 stores norm scale/bias as '<name>_s' / '<name>_b'.
                    mapping[py_prefix + norm + '%d.weight' % i] \
                        = cf_full + norm + '_s'
                    mapping[py_prefix + norm + '%d.bias' % i] \
                        = cf_full + norm + '_b'
            # downsample branch
            cf_full = 'res%d_0_branch1_' % (layer + 1)
            py_full = 'layer%d.0.downsample.' % layer
            mapping[py_full + '0.weight'] = cf_full + 'w'
            mapping[py_full + '1.weight'] = cf_full + norm + '_s'
            mapping[py_full + '1.bias'] = cf_full + norm + '_b'
        # stem layers and last fc layer
        if self.normalize['type'] == 'GN':
            mapping['conv1.weight'] = 'conv1_w'
            mapping['gn1.weight'] = 'conv1_gn_s'
            mapping['gn1.bias'] = 'conv1_gn_b'
            mapping['fc.weight'] = 'pred_w'
            mapping['fc.bias'] = 'pred_b'
        else:
            mapping['conv1.weight'] = 'conv1_w'
            mapping['bn1.weight'] = 'res_conv1_bn_s'
            mapping['bn1.bias'] = 'res_conv1_bn_b'
            mapping['fc.weight'] = 'fc1000_w'
            mapping['fc.bias'] = 'fc1000_b'
        # load state dict
        py_state = self.state_dict()
        with open(cf_path, 'rb') as f:
            # latin1 encoding is required to unpickle Python-2 Caffe2 blobs.
            # NOTE(review): pickle.load on an untrusted file can execute
            # arbitrary code — only load trusted checkpoints.
            cf_state = pickle.load(f, encoding='latin1')
        if 'blobs' in cf_state:
            cf_state = cf_state['blobs']
        for py_k, cf_k in mapping.items():
            print('Loading {} to {}'.format(cf_k, py_k))
            assert py_k in py_state and cf_k in cf_state
            py_state[py_k] = torch.Tensor(cf_state[cf_k])
        self.load_state_dict(py_state)

    def forward(self, x):
        """Run the stem, all residual stages, pooling and the classifier head."""
        x = self.conv1(x)
        # stem_norm_name is 'gn1' or 'bn1' depending on the normalize config.
        x = getattr(self, self.stem_norm_name)(x)
        x = self.relu(x)
        x = self.maxpool(x)
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
mmdet/models/bbox_heads/convfc_bbox_head.py
View file @
8500c14e
import
torch.nn
as
nn
from
.bbox_head
import
BBoxHead
from
..utils
import
ConvModule
,
build_norm_layer
from
..utils
import
ConvModule
class
ConvFCBBoxHead
(
BBoxHead
):
...
...
@@ -113,14 +113,8 @@ class ConvFCBBoxHead(BBoxHead):
for
i
in
range
(
num_branch_fcs
):
fc_in_channels
=
(
last_layer_dim
if
i
==
0
else
self
.
fc_out_channels
)
if
self
.
normalize
is
not
None
:
branch_fcs
.
append
(
nn
.
Sequential
(
nn
.
Linear
(
fc_in_channels
,
self
.
fc_out_channels
,
False
),
build_norm_layer
(
self
.
normalize
,
self
.
fc_out_channels
))
)
else
:
branch_fcs
.
append
(
nn
.
Linear
(
fc_in_channels
,
self
.
fc_out_channels
))
branch_fcs
.
append
(
nn
.
Linear
(
fc_in_channels
,
self
.
fc_out_channels
))
last_layer_dim
=
self
.
fc_out_channels
return
branch_convs
,
branch_fcs
,
last_layer_dim
...
...
@@ -130,8 +124,7 @@ class ConvFCBBoxHead(BBoxHead):
for
m
in
module_list
.
modules
():
if
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
forward
(
self
,
x
):
# shared part
...
...
mmdet/models/utils/norm.py
View file @
8500c14e
...
...
@@ -6,14 +6,16 @@ norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}
def
build_norm_layer
(
cfg
,
num_features
):
assert
isinstance
(
cfg
,
dict
)
and
'type'
in
cfg
cfg_
=
cfg
.
copy
()
cfg_
.
setdefault
(
'eps'
,
1e-5
)
layer_type
=
cfg_
.
pop
(
'type'
)
# args name matching
if
layer_type
==
'GN'
:
assert
'num_groups'
in
cfg
cfg_
.
setdefault
(
'num_channels'
,
num_features
)
else
:
elif
layer_type
==
'BN'
:
cfg_
=
dict
()
# rewrite neccessary info for BN from here
cfg_
.
setdefault
(
'num_features'
,
num_features
)
cfg_
.
setdefault
(
'eps'
,
1e-5
)
if
layer_type
not
in
norm_cfg
:
raise
KeyError
(
'Unrecognized norm type {}'
.
format
(
layer_type
))
...
...
tools/train_imagenet/train_imagenet.py
0 → 100644
View file @
8500c14e
import
argparse
import
os
import
random
import
shutil
import
time
import
warnings
import
sys
import
torch
import
torch.nn
as
nn
import
torch.nn.parallel
import
torch.backends.cudnn
as
cudnn
import
torch.distributed
as
dist
import
torch.optim
import
torch.multiprocessing
as
mp
import
torch.utils.data
import
torch.utils.data.distributed
import
torchvision.transforms
as
transforms
import
torchvision.datasets
as
datasets
import
torchvision.models
as
models
from
mmdet.models.backbones.resnet
import
*
# Torchvision architectures exposed via --arch (lowercase callables only).
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

# Command-line interface. This mirrors the stock PyTorch ImageNet example,
# plus the extra --cf_path option for a Caffe2 checkpoint.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
# Path to the Caffe2 pickle weights consumed by ResNetClassifier.load_caffe2_weight.
parser.add_argument('--cf_path', type=str, default='.')

# Best top-1 accuracy seen so far; updated by main_worker across epochs.
best_acc1 = 0
def main():
    """Entry point: parse CLI args and launch one or many training workers."""
    args = parser.parse_args()

    if args.seed is not None:
        # Seeding also forces deterministic cuDNN kernels, which can be slower.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        # With env:// initialization the launcher provides the world size.
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
def main_worker(gpu, ngpus_per_node, args):
    """Per-process worker: build the model, data loaders and run the epochs.

    Called once directly (single-process) or once per GPU via mp.spawn
    (multiprocessing-distributed). `gpu` is this worker's local GPU id or
    None.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    #if args.pretrained:
    #    print("=> using pre-trained model '{}'".format(args.arch))
    #    model = models.__dict__[args.arch](pretrained=True)
    #else:
    #    print("=> creating model '{}'".format(args.arch))
    #    model = models.__dict__[args.arch]()
    # NOTE(review): --arch is ignored; the model is hard-wired to a
    # GroupNorm ResNet-50 loaded from Caffe2 weights.
    model = ResNetClassifier(50, normalize=dict(type='GN', num_groups=32))
    model.load_caffe2_weight(args.cf_path)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN autotune convolution algorithms (input sizes are fixed).
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # Standard ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Re-shuffle the sampler's partition each epoch.
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Only rank-0 of each node writes checkpoints to avoid clobbering.
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch; logs timing, loss and accuracy every
    `args.print_freq` iterations."""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            input = input.cuda(args.gpu, non_blocking=True)
        # NOTE(review): target is moved to GPU unconditionally (matches the
        # upstream example); this requires CUDA to be available.
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion, args):
    """Evaluate the model on the validation set and return the average
    top-1 accuracy (the value tracked by the caller as best_acc1)."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            if args.gpu is not None:
                input = input.cuda(args.gpu, non_blocking=True)
            # NOTE(review): target is moved to GPU unconditionally (matches
            # the upstream example); requires CUDA to be available.
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0], input.size(0))
            top5.update(acc5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize *state* to *filename*.

    When *is_best* is true, additionally copy the checkpoint to
    'model_best.pth.tar' so the best model is always available by name.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Computes and stores the average and current value.

    Tracks the most recent value (`val`), the running total (`sum`),
    the number of observations (`count`) and their mean (`avg`).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running mean."""
        self.val = val
        total = self.sum + val * n
        observations = self.count + n
        self.sum = total
        self.count = observations
        self.avg = total / observations
def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    # Step decay: multiply the base LR by 0.1 once per completed 30-epoch span.
    decayed = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: (batch, num_classes) tensor of class scores.
        target: (batch,) tensor of ground-truth class indices.
        topk: iterable of k values to report accuracy for.

    Returns:
        List of 1-element tensors, one per k, each the top-k accuracy
        in percent over the batch.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred: (batch, maxk) indices of the top-maxk scores, then
        # transposed to (maxk, batch) so row k-1 holds the k-th guesses.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Use reshape(-1) rather than view(-1): view() raises a
            # RuntimeError on non-contiguous tensors, and slices of the
            # comparison result are not guaranteed contiguous in all
            # PyTorch versions (same fix as upstream pytorch/examples).
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment