Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6db23a1f
Commit
6db23a1f
authored
Jun 24, 2021
by
Atze00
Browse files
added tests and fixed some wrong behaviour
parent
bb3cc770
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
8 deletions
+80
-8
official/vision/beta/modeling/layers/nn_layers.py
official/vision/beta/modeling/layers/nn_layers.py
+8
-8
official/vision/beta/modeling/layers/nn_layers_test.py
official/vision/beta/modeling/layers/nn_layers_test.py
+72
-0
No files found.
official/vision/beta/modeling/layers/nn_layers.py
View file @
6db23a1f
...
@@ -706,7 +706,7 @@ class CausalConvMixin:
...
@@ -706,7 +706,7 @@ class CausalConvMixin:
self
.
_use_buffered_input
=
variable
self
.
_use_buffered_input
=
variable
def
_compute_buffered_causal_padding
(
self
,
def
_compute_buffered_causal_padding
(
self
,
inputs
:
Optional
[
tf
.
Tensor
]
=
None
,
inputs
:
tf
.
Tensor
,
use_buffered_input
:
bool
=
False
,
use_buffered_input
:
bool
=
False
,
time_axis
:
int
=
1
)
->
List
[
List
[
int
]]:
time_axis
:
int
=
1
)
->
List
[
List
[
int
]]:
"""Calculates padding for 'causal' option for conv layers.
"""Calculates padding for 'causal' option for conv layers.
...
@@ -720,7 +720,7 @@ class CausalConvMixin:
...
@@ -720,7 +720,7 @@ class CausalConvMixin:
Returns:
Returns:
A list of paddings for `tf.pad`.
A list of paddings for `tf.pad`.
"""
"""
shape_in
=
inputs
.
shape
[
1
:
-
1
]
input_shape
=
tf
.
shape
(
inputs
)
[
1
:
-
1
]
del
inputs
del
inputs
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_first'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_first'
:
...
@@ -731,11 +731,10 @@ class CausalConvMixin:
...
@@ -731,11 +731,10 @@ class CausalConvMixin:
(
self
.
kernel_size
[
i
]
-
1
)
*
(
self
.
dilation_rate
[
i
]
-
1
))
(
self
.
kernel_size
[
i
]
-
1
)
*
(
self
.
dilation_rate
[
i
]
-
1
))
for
i
in
range
(
self
.
rank
)
for
i
in
range
(
self
.
rank
)
]
]
pad_total
=
[
max
(
kernel_size_effective
[
i
]
-
(
self
.
strides
[
i
]),
0
)
pad_total
=
[
kernel_size_effective
[
0
]
-
1
]
if
(
shape_in
[
i
]
%
self
.
strides
[
i
])
==
0
else
for
i
in
range
(
1
,
self
.
rank
):
max
(
kernel_size_effective
[
i
]
-
overlap
=
(
input_shape
[
i
]
-
1
)
%
self
.
strides
[
i
]
+
1
(
shape_in
[
i
]
%
self
.
strides
[
i
]),
0
)
pad_total
.
append
(
tf
.
maximum
(
kernel_size_effective
[
i
]
-
overlap
,
0
))
for
i
in
range
(
self
.
rank
)]
pad_beg
=
[
pad_total
[
i
]
//
2
for
i
in
range
(
self
.
rank
)]
pad_beg
=
[
pad_total
[
i
]
//
2
for
i
in
range
(
self
.
rank
)]
pad_end
=
[
pad_total
[
i
]
-
pad_beg
[
i
]
for
i
in
range
(
self
.
rank
)]
pad_end
=
[
pad_total
[
i
]
-
pad_beg
[
i
]
for
i
in
range
(
self
.
rank
)]
padding
=
[[
pad_beg
[
i
],
pad_end
[
i
]]
for
i
in
range
(
self
.
rank
)]
padding
=
[[
pad_beg
[
i
],
pad_end
[
i
]]
for
i
in
range
(
self
.
rank
)]
...
@@ -768,7 +767,8 @@ class CausalConvMixin:
...
@@ -768,7 +767,8 @@ class CausalConvMixin:
# across time should be the input shape minus any padding, assuming
# across time should be the input shape minus any padding, assuming
# the stride across time is 1.
# the stride across time is 1.
if
self
.
_use_buffered_input
and
spatial_output_shape
[
0
]
is
not
None
:
if
self
.
_use_buffered_input
and
spatial_output_shape
[
0
]
is
not
None
:
padding
=
self
.
_compute_buffered_causal_padding
(
use_buffered_input
=
False
)
padding
=
self
.
_compute_buffered_causal_padding
(
tf
.
zeros
([
1
]
+
spatial_output_shape
+
[
1
]),
use_buffered_input
=
False
)
spatial_output_shape
[
0
]
-=
sum
(
padding
[
1
])
spatial_output_shape
[
0
]
-=
sum
(
padding
[
1
])
return
spatial_output_shape
return
spatial_output_shape
...
...
official/vision/beta/modeling/layers/nn_layers_test.py
View file @
6db23a1f
...
@@ -320,6 +320,9 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -320,6 +320,9 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
[[
12.
,
12.
,
12.
],
[[
12.
,
12.
,
12.
],
[
8.
,
8.
,
8.
]]]]])
[
8.
,
8.
,
8.
]]]]])
output_shape
=
conv3d
.
_spatial_output_shape
([
4
,
4
,
4
])
self
.
assertAllClose
(
output_shape
,
[
2
,
2
,
2
])
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
expected
)
...
@@ -329,5 +332,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -329,5 +332,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
expected
)
def
test_conv3d_causal_padding_2d
(
self
):
"""Test to ensure causal padding works like standard padding."""
conv3d
=
nn_layers
.
Conv3D
(
filters
=
1
,
kernel_size
=
(
1
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
padding
=
'causal'
,
use_buffered_input
=
False
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
)
keras_conv3d
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
1
,
kernel_size
=
(
1
,
3
,
3
),
strides
=
(
1
,
2
,
2
),
padding
=
'same'
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
)
inputs
=
tf
.
ones
([
1
,
1
,
4
,
4
,
1
])
predicted
=
conv3d
(
inputs
)
expected
=
keras_conv3d
(
inputs
)
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
9.
],
[
6.
]],
[[
6.
],
[
4.
]]]]])
def
test_conv3d_causal_padding_1d
(
self
):
"""Test to ensure causal padding works like standard padding."""
conv3d
=
nn_layers
.
Conv3D
(
filters
=
1
,
kernel_size
=
(
3
,
1
,
1
),
strides
=
(
2
,
1
,
1
),
padding
=
'causal'
,
use_buffered_input
=
False
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
)
keras_conv1d
=
tf
.
keras
.
layers
.
Conv1D
(
filters
=
1
,
kernel_size
=
3
,
strides
=
2
,
padding
=
'causal'
,
kernel_initializer
=
'ones'
,
use_bias
=
False
,
)
inputs
=
tf
.
ones
([
1
,
4
,
1
,
1
,
1
])
predicted
=
conv3d
(
inputs
)
expected
=
keras_conv1d
(
tf
.
squeeze
(
inputs
,
axis
=
[
2
,
3
]))
expected
=
tf
.
reshape
(
expected
,
[
1
,
2
,
1
,
1
,
1
])
self
.
assertEqual
(
predicted
.
shape
,
expected
.
shape
)
self
.
assertAllClose
(
predicted
,
expected
)
self
.
assertAllClose
(
predicted
,
[[[[[
1.
]]],
[[[
3.
]]]]])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment