Commit 50d9dc5f authored Nov 12, 2018 by Amir Arsalan Soltani, committed by Francisco Massa on Nov 12, 2018

Update densenet.py (#658)

* Update densenet.py
* Update densenet.py

parent e3759081
Showing 1 changed file with 101 additions and 101 deletions.
torchvision/models/densenet.py (+101, -101)
...
...
@@ -16,6 +16,107 @@ model_urls = {
}
class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return torch.cat([x, new_features], 1)
class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
                                bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out
def densenet121(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
...
...
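The diff is collapsed between the two hunks. As an aside for readers of the relocated code: the constructor defaults above are exactly the DenseNet-121 configuration, so the class can be smoke-tested directly. A minimal sketch (ours, not part of the commit; it assumes torch is importable alongside this module, and a 224x224 input, which the hard-coded kernel_size=7 average pool in forward() requires):

import torch

# Defaults match DenseNet-121: growth_rate=32, block_config=(6, 12, 24, 16),
# num_init_features=64, giving 1024 features before the classifier.
model = DenseNet()
model.eval()

x = torch.randn(1, 3, 224, 224)  # spatial size must reach 7x7 at the final
                                 # avg_pool2d, hence the 224x224 assumption
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])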
@@ -122,104 +223,3 @@ def densenet161(pretrained=False, **kwargs):
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model
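The context lines above are the tail of the pretrained-weight loading shared by the densenet* factory functions. For orientation, a hedged sketch of the surrounding pattern in torchvision of this era (reconstructed, not visible in this diff; _load_pretrained is a hypothetical helper name, and the regex is reproduced from memory):

import re
import torch.utils.model_zoo as model_zoo

def _load_pretrained(model, model_url):
    # Hypothetical helper: older DenseNet checkpoints used dotted submodule
    # names such as 'denselayer1.norm.1.weight'; newer PyTorch forbids dots
    # inside module names, so keys are rewritten to 'denselayer1.norm1.weight'
    # before loading.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.'
        r'((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_url)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)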
[The 101 deleted lines in this hunk are the identical definitions of _DenseLayer, _DenseBlock, _Transition, and DenseNet reproduced above: the commit removes them from their old position after the densenet* factory functions, so the classes are now defined once, before the functions that use them.]