Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
5785e2b0
Unverified
Commit
5785e2b0
authored
Dec 12, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Dec 12, 2022
Browse files
Allow dropout overwrites on EfficientNet (#7031)
parent
5a75fa9f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
14 deletions
+36
-14
test/smoke_test.py
test/smoke_test.py
+5
-3
torchvision/models/efficientnet.py
torchvision/models/efficientnet.py
+31
-11
No files found.
test/smoke_test.py
View file @
5785e2b0
...
@@ -17,6 +17,7 @@ def smoke_test_torchvision() -> None:
...
@@ -17,6 +17,7 @@ def smoke_test_torchvision() -> None:
all
(
x
is
not
None
for
x
in
[
torch
.
ops
.
image
.
decode_png
,
torch
.
ops
.
torchvision
.
roi_align
]),
all
(
x
is
not
None
for
x
in
[
torch
.
ops
.
image
.
decode_png
,
torch
.
ops
.
torchvision
.
roi_align
]),
)
)
def
smoke_test_torchvision_read_decode
()
->
None
:
def
smoke_test_torchvision_read_decode
()
->
None
:
img_jpg
=
read_image
(
str
(
SCRIPT_DIR
/
"assets"
/
"encode_jpeg"
/
"grace_hopper_517x606.jpg"
))
img_jpg
=
read_image
(
str
(
SCRIPT_DIR
/
"assets"
/
"encode_jpeg"
/
"grace_hopper_517x606.jpg"
))
if
img_jpg
.
ndim
!=
3
or
img_jpg
.
numel
()
<
100
:
if
img_jpg
.
ndim
!=
3
or
img_jpg
.
numel
()
<
100
:
...
@@ -25,6 +26,7 @@ def smoke_test_torchvision_read_decode() -> None:
...
@@ -25,6 +26,7 @@ def smoke_test_torchvision_read_decode() -> None:
if
img_png
.
ndim
!=
3
or
img_png
.
numel
()
<
100
:
if
img_png
.
ndim
!=
3
or
img_png
.
numel
()
<
100
:
raise
RuntimeError
(
f
"Unexpected shape of img_png:
{
img_png
.
shape
}
"
)
raise
RuntimeError
(
f
"Unexpected shape of img_png:
{
img_png
.
shape
}
"
)
def
smoke_test_torchvision_resnet50_classify
(
device
:
str
=
"cpu"
)
->
None
:
def
smoke_test_torchvision_resnet50_classify
(
device
:
str
=
"cpu"
)
->
None
:
img
=
read_image
(
str
(
SCRIPT_DIR
/
".."
/
"gallery"
/
"assets"
/
"dog2.jpg"
)).
to
(
device
)
img
=
read_image
(
str
(
SCRIPT_DIR
/
".."
/
"gallery"
/
"assets"
/
"dog2.jpg"
)).
to
(
device
)
...
@@ -47,9 +49,8 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
...
@@ -47,9 +49,8 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
expected_category
=
"German shepherd"
expected_category
=
"German shepherd"
print
(
f
"
{
category_name
}
(
{
device
}
):
{
100
*
score
:.
1
f
}
%"
)
print
(
f
"
{
category_name
}
(
{
device
}
):
{
100
*
score
:.
1
f
}
%"
)
if
category_name
!=
expected_category
:
if
category_name
!=
expected_category
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Failed ResNet50 classify
{
category_name
}
Expected:
{
expected_category
}
"
)
f
"Failed ResNet50 classify
{
category_name
}
Expected:
{
expected_category
}
"
)
def
main
()
->
None
:
def
main
()
->
None
:
print
(
f
"torchvision:
{
torchvision
.
__version__
}
"
)
print
(
f
"torchvision:
{
torchvision
.
__version__
}
"
)
...
@@ -59,5 +60,6 @@ def main() -> None:
...
@@ -59,5 +60,6 @@ def main() -> None:
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
smoke_test_torchvision_resnet50_classify
(
"cuda"
)
smoke_test_torchvision_resnet50_classify
(
"cuda"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
torchvision/models/efficientnet.py
View file @
5785e2b0
...
@@ -779,7 +779,9 @@ def efficientnet_b0(
...
@@ -779,7 +779,9 @@ def efficientnet_b0(
weights
=
EfficientNet_B0_Weights
.
verify
(
weights
)
weights
=
EfficientNet_B0_Weights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b0"
,
width_mult
=
1.0
,
depth_mult
=
1.0
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b0"
,
width_mult
=
1.0
,
depth_mult
=
1.0
)
return
_efficientnet
(
inverted_residual_setting
,
0.2
,
last_channel
,
weights
,
progress
,
**
kwargs
)
return
_efficientnet
(
inverted_residual_setting
,
kwargs
.
pop
(
"dropout"
,
0.2
),
last_channel
,
weights
,
progress
,
**
kwargs
)
@
register_model
()
@
register_model
()
...
@@ -808,7 +810,9 @@ def efficientnet_b1(
...
@@ -808,7 +810,9 @@ def efficientnet_b1(
weights
=
EfficientNet_B1_Weights
.
verify
(
weights
)
weights
=
EfficientNet_B1_Weights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b1"
,
width_mult
=
1.0
,
depth_mult
=
1.1
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b1"
,
width_mult
=
1.0
,
depth_mult
=
1.1
)
return
_efficientnet
(
inverted_residual_setting
,
0.2
,
last_channel
,
weights
,
progress
,
**
kwargs
)
return
_efficientnet
(
inverted_residual_setting
,
kwargs
.
pop
(
"dropout"
,
0.2
),
last_channel
,
weights
,
progress
,
**
kwargs
)
@
register_model
()
@
register_model
()
...
@@ -837,7 +841,9 @@ def efficientnet_b2(
...
@@ -837,7 +841,9 @@ def efficientnet_b2(
weights
=
EfficientNet_B2_Weights
.
verify
(
weights
)
weights
=
EfficientNet_B2_Weights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b2"
,
width_mult
=
1.1
,
depth_mult
=
1.2
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b2"
,
width_mult
=
1.1
,
depth_mult
=
1.2
)
return
_efficientnet
(
inverted_residual_setting
,
0.3
,
last_channel
,
weights
,
progress
,
**
kwargs
)
return
_efficientnet
(
inverted_residual_setting
,
kwargs
.
pop
(
"dropout"
,
0.3
),
last_channel
,
weights
,
progress
,
**
kwargs
)
@
register_model
()
@
register_model
()
...
@@ -866,7 +872,14 @@ def efficientnet_b3(
...
@@ -866,7 +872,14 @@ def efficientnet_b3(
weights
=
EfficientNet_B3_Weights
.
verify
(
weights
)
weights
=
EfficientNet_B3_Weights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b3"
,
width_mult
=
1.2
,
depth_mult
=
1.4
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b3"
,
width_mult
=
1.2
,
depth_mult
=
1.4
)
return
_efficientnet
(
inverted_residual_setting
,
0.3
,
last_channel
,
weights
,
progress
,
**
kwargs
)
return
_efficientnet
(
inverted_residual_setting
,
kwargs
.
pop
(
"dropout"
,
0.3
),
last_channel
,
weights
,
progress
,
**
kwargs
,
)
@
register_model
()
@
register_model
()
...
@@ -895,7 +908,14 @@ def efficientnet_b4(
...
@@ -895,7 +908,14 @@ def efficientnet_b4(
weights
=
EfficientNet_B4_Weights
.
verify
(
weights
)
weights
=
EfficientNet_B4_Weights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b4"
,
width_mult
=
1.4
,
depth_mult
=
1.8
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b4"
,
width_mult
=
1.4
,
depth_mult
=
1.8
)
return
_efficientnet
(
inverted_residual_setting
,
0.4
,
last_channel
,
weights
,
progress
,
**
kwargs
)
return
_efficientnet
(
inverted_residual_setting
,
kwargs
.
pop
(
"dropout"
,
0.4
),
last_channel
,
weights
,
progress
,
**
kwargs
,
)
@
register_model
()
@
register_model
()
...
@@ -926,7 +946,7 @@ def efficientnet_b5(
...
@@ -926,7 +946,7 @@ def efficientnet_b5(
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b5"
,
width_mult
=
1.6
,
depth_mult
=
2.2
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b5"
,
width_mult
=
1.6
,
depth_mult
=
2.2
)
return
_efficientnet
(
return
_efficientnet
(
inverted_residual_setting
,
inverted_residual_setting
,
0.4
,
kwargs
.
pop
(
"dropout"
,
0.4
)
,
last_channel
,
last_channel
,
weights
,
weights
,
progress
,
progress
,
...
@@ -963,7 +983,7 @@ def efficientnet_b6(
...
@@ -963,7 +983,7 @@ def efficientnet_b6(
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b6"
,
width_mult
=
1.8
,
depth_mult
=
2.6
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b6"
,
width_mult
=
1.8
,
depth_mult
=
2.6
)
return
_efficientnet
(
return
_efficientnet
(
inverted_residual_setting
,
inverted_residual_setting
,
0.5
,
kwargs
.
pop
(
"dropout"
,
0.5
)
,
last_channel
,
last_channel
,
weights
,
weights
,
progress
,
progress
,
...
@@ -1000,7 +1020,7 @@ def efficientnet_b7(
...
@@ -1000,7 +1020,7 @@ def efficientnet_b7(
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b7"
,
width_mult
=
2.0
,
depth_mult
=
3.1
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_b7"
,
width_mult
=
2.0
,
depth_mult
=
3.1
)
return
_efficientnet
(
return
_efficientnet
(
inverted_residual_setting
,
inverted_residual_setting
,
0.5
,
kwargs
.
pop
(
"dropout"
,
0.5
)
,
last_channel
,
last_channel
,
weights
,
weights
,
progress
,
progress
,
...
@@ -1038,7 +1058,7 @@ def efficientnet_v2_s(
...
@@ -1038,7 +1058,7 @@ def efficientnet_v2_s(
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_v2_s"
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_v2_s"
)
return
_efficientnet
(
return
_efficientnet
(
inverted_residual_setting
,
inverted_residual_setting
,
0.2
,
kwargs
.
pop
(
"dropout"
,
0.2
)
,
last_channel
,
last_channel
,
weights
,
weights
,
progress
,
progress
,
...
@@ -1076,7 +1096,7 @@ def efficientnet_v2_m(
...
@@ -1076,7 +1096,7 @@ def efficientnet_v2_m(
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_v2_m"
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_v2_m"
)
return
_efficientnet
(
return
_efficientnet
(
inverted_residual_setting
,
inverted_residual_setting
,
0.3
,
kwargs
.
pop
(
"dropout"
,
0.3
)
,
last_channel
,
last_channel
,
weights
,
weights
,
progress
,
progress
,
...
@@ -1114,7 +1134,7 @@ def efficientnet_v2_l(
...
@@ -1114,7 +1134,7 @@ def efficientnet_v2_l(
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_v2_l"
)
inverted_residual_setting
,
last_channel
=
_efficientnet_conf
(
"efficientnet_v2_l"
)
return
_efficientnet
(
return
_efficientnet
(
inverted_residual_setting
,
inverted_residual_setting
,
0.4
,
kwargs
.
pop
(
"dropout"
,
0.4
)
,
last_channel
,
last_channel
,
weights
,
weights
,
progress
,
progress
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment