Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ddc9bce0
Commit
ddc9bce0
authored
Jul 19, 2022
by
A. Unique TensorFlower
Browse files
Merge pull request #10626 from ryan0507:master
PiperOrigin-RevId: 461948867
parents
4323d37c
e3fc61e7
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
142 additions
and
83 deletions
+142
-83
official/projects/yt8m/configs/yt8m_test.py
official/projects/yt8m/configs/yt8m_test.py
+40
-0
official/projects/yt8m/dataloaders/utils.py
official/projects/yt8m/dataloaders/utils.py
+23
-75
official/projects/yt8m/dataloaders/yt8m_input.py
official/projects/yt8m/dataloaders/yt8m_input.py
+1
-1
official/projects/yt8m/dataloaders/yt8m_input_test.py
official/projects/yt8m/dataloaders/yt8m_input_test.py
+3
-2
official/projects/yt8m/eval_utils/average_precision_calculator.py
.../projects/yt8m/eval_utils/average_precision_calculator.py
+1
-2
official/projects/yt8m/eval_utils/eval_util_test.py
official/projects/yt8m/eval_utils/eval_util_test.py
+70
-0
official/projects/yt8m/modeling/nn_layers.py
official/projects/yt8m/modeling/nn_layers.py
+0
-0
official/projects/yt8m/modeling/yt8m_model.py
official/projects/yt8m/modeling/yt8m_model.py
+3
-2
official/projects/yt8m/train_test.py
official/projects/yt8m/train_test.py
+1
-1
No files found.
official/projects/yt8m/configs/yt8m_test.py
0 → 100644
View file @
ddc9bce0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.projects.yt8m.configs
import
yt8m
# pylint: disable=unused-import
from
official.projects.yt8m.configs.yt8m
import
yt8m
as
exp_cfg
class
YT8MTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
parameters
(
(
'yt8m_experiment'
,),)
def
test_yt8m_configs
(
self
,
config_name
):
config
=
exp_factory
.
get_exp_config
(
config_name
)
self
.
assertIsInstance
(
config
,
cfg
.
ExperimentConfig
)
self
.
assertIsInstance
(
config
.
task
,
cfg
.
TaskConfig
)
self
.
assertIsInstance
(
config
.
task
.
model
,
hyperparams
.
Config
)
self
.
assertIsInstance
(
config
.
task
.
train_data
,
cfg
.
DataConfig
)
config
.
task
.
train_data
.
is_training
=
None
with
self
.
assertRaises
(
KeyError
):
config
.
validate
()
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/projects/yt8m/dataloaders/utils.py
View file @
ddc9bce0
...
@@ -20,7 +20,7 @@ import tensorflow as tf
...
@@ -20,7 +20,7 @@ import tensorflow as tf
from
official.vision.dataloaders
import
tfexample_utils
from
official.vision.dataloaders
import
tfexample_utils
def
D
equantize
(
feat_vector
,
max_quantized_value
=
2
,
min_quantized_value
=-
2
):
def
d
equantize
(
feat_vector
,
max_quantized_value
=
2
,
min_quantized_value
=-
2
):
"""Dequantize the feature from the byte format to the float format.
"""Dequantize the feature from the byte format to the float format.
Args:
Args:
...
@@ -38,7 +38,7 @@ def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
...
@@ -38,7 +38,7 @@ def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
return
feat_vector
*
scalar
+
bias
return
feat_vector
*
scalar
+
bias
def
M
ake
S
ummary
(
name
,
value
):
def
m
ake
_s
ummary
(
name
,
value
):
"""Creates a tf.Summary proto with the given name and value."""
"""Creates a tf.Summary proto with the given name and value."""
summary
=
tf
.
Summary
()
summary
=
tf
.
Summary
()
val
=
summary
.
value
.
add
()
val
=
summary
.
value
.
add
()
...
@@ -47,10 +47,10 @@ def MakeSummary(name, value):
...
@@ -47,10 +47,10 @@ def MakeSummary(name, value):
return
summary
return
summary
def
A
dd
G
lobal
S
tep
S
ummary
(
summary_writer
,
def
a
dd
_g
lobal
_s
tep
_s
ummary
(
summary_writer
,
global_step_val
,
global_step_val
,
global_step_info_dict
,
global_step_info_dict
,
summary_scope
=
"Eval"
):
summary_scope
=
"Eval"
):
"""Add the global_step summary to the Tensorboard.
"""Add the global_step summary to the Tensorboard.
Args:
Args:
...
@@ -69,19 +69,19 @@ def AddGlobalStepSummary(summary_writer,
...
@@ -69,19 +69,19 @@ def AddGlobalStepSummary(summary_writer,
examples_per_second
=
global_step_info_dict
.
get
(
"examples_per_second"
,
-
1
)
examples_per_second
=
global_step_info_dict
.
get
(
"examples_per_second"
,
-
1
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Hit@1"
,
this_hit_at_one
),
m
ake
_s
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Hit@1"
,
this_hit_at_one
),
global_step_val
)
global_step_val
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Perr"
,
this_perr
),
m
ake
_s
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Perr"
,
this_perr
),
global_step_val
)
global_step_val
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Loss"
,
this_loss
),
m
ake
_s
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Loss"
,
this_loss
),
global_step_val
)
global_step_val
)
if
examples_per_second
!=
-
1
:
if
examples_per_second
!=
-
1
:
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Example_Second"
,
m
ake
_s
ummary
(
"GlobalStep/"
+
summary_scope
+
"_Example_Second"
,
examples_per_second
),
global_step_val
)
examples_per_second
),
global_step_val
)
summary_writer
.
flush
()
summary_writer
.
flush
()
info
=
(
info
=
(
...
@@ -92,10 +92,10 @@ def AddGlobalStepSummary(summary_writer,
...
@@ -92,10 +92,10 @@ def AddGlobalStepSummary(summary_writer,
return
info
return
info
def
A
dd
E
poch
S
ummary
(
summary_writer
,
def
a
dd
_e
poch
_s
ummary
(
summary_writer
,
global_step_val
,
global_step_val
,
epoch_info_dict
,
epoch_info_dict
,
summary_scope
=
"Eval"
):
summary_scope
=
"Eval"
):
"""Add the epoch summary to the Tensorboard.
"""Add the epoch summary to the Tensorboard.
Args:
Args:
...
@@ -117,18 +117,18 @@ def AddEpochSummary(summary_writer,
...
@@ -117,18 +117,18 @@ def AddEpochSummary(summary_writer,
mean_ap
=
np
.
mean
(
aps
)
mean_ap
=
np
.
mean
(
aps
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"Epoch/"
+
summary_scope
+
"_Avg_Hit@1"
,
avg_hit_at_one
),
m
ake
_s
ummary
(
"Epoch/"
+
summary_scope
+
"_Avg_Hit@1"
,
avg_hit_at_one
),
global_step_val
)
global_step_val
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"Epoch/"
+
summary_scope
+
"_Avg_Perr"
,
avg_perr
),
m
ake
_s
ummary
(
"Epoch/"
+
summary_scope
+
"_Avg_Perr"
,
avg_perr
),
global_step_val
)
global_step_val
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"Epoch/"
+
summary_scope
+
"_Avg_Loss"
,
avg_loss
),
m
ake
_s
ummary
(
"Epoch/"
+
summary_scope
+
"_Avg_Loss"
,
avg_loss
),
global_step_val
)
global_step_val
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"Epoch/"
+
summary_scope
+
"_MAP"
,
mean_ap
),
global_step_val
)
m
ake
_s
ummary
(
"Epoch/"
+
summary_scope
+
"_MAP"
,
mean_ap
),
global_step_val
)
summary_writer
.
add_summary
(
summary_writer
.
add_summary
(
M
ake
S
ummary
(
"Epoch/"
+
summary_scope
+
"_GAP"
,
gap
),
global_step_val
)
m
ake
_s
ummary
(
"Epoch/"
+
summary_scope
+
"_GAP"
,
gap
),
global_step_val
)
summary_writer
.
flush
()
summary_writer
.
flush
()
info
=
(
"epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} "
info
=
(
"epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} "
...
@@ -138,7 +138,7 @@ def AddEpochSummary(summary_writer,
...
@@ -138,7 +138,7 @@ def AddEpochSummary(summary_writer,
return
info
return
info
def
G
et
L
ist
OfF
eature
N
ames
AndS
izes
(
feature_names
,
feature_sizes
):
def
g
et
_l
ist
_of_f
eature
_n
ames
_and_s
izes
(
feature_names
,
feature_sizes
):
"""Extract the list of feature names and the dimensionality.
"""Extract the list of feature names and the dimensionality.
Args:
Args:
...
@@ -164,59 +164,7 @@ def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
...
@@ -164,59 +164,7 @@ def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
return
list_of_feature_names
,
list_of_feature_sizes
return
list_of_feature_names
,
list_of_feature_sizes
def
ClipGradientNorms
(
gradients_to_variables
,
max_norm
):
def
make_yt8m_example
(
num_segment
:
int
=
5
)
->
tf
.
train
.
SequenceExample
:
"""Clips the gradients by the given value.
Args:
gradients_to_variables: A list of gradient to variable pairs (tuples).
max_norm: the maximum norm value.
Returns:
A list of clipped gradient to variable pairs.
"""
clipped_grads_and_vars
=
[]
for
grad
,
var
in
gradients_to_variables
:
if
grad
is
not
None
:
if
isinstance
(
grad
,
tf
.
IndexedSlices
):
tmp
=
tf
.
clip_by_norm
(
grad
.
values
,
max_norm
)
grad
=
tf
.
IndexedSlices
(
tmp
,
grad
.
indices
,
grad
.
dense_shape
)
else
:
grad
=
tf
.
clip_by_norm
(
grad
,
max_norm
)
clipped_grads_and_vars
.
append
((
grad
,
var
))
return
clipped_grads_and_vars
def
CombineGradients
(
tower_grads
):
"""Calculate the combined gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list is
over individual gradients. The inner list is over the gradient calculation
for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been summed
across all towers.
"""
filtered_grads
=
[
[
x
for
x
in
grad_list
if
x
[
0
]
is
not
None
]
for
grad_list
in
tower_grads
]
final_grads
=
[]
for
i
in
range
(
len
(
filtered_grads
[
0
])):
grads
=
[
filtered_grads
[
t
][
i
]
for
t
in
range
(
len
(
filtered_grads
))]
grad
=
tf
.
stack
([
x
[
0
]
for
x
in
grads
],
0
)
grad
=
tf
.
reduce_sum
(
grad
,
0
)
final_grads
.
append
((
grad
,
filtered_grads
[
0
][
i
][
1
],
))
return
final_grads
def
MakeYt8mExample
(
num_segment
:
int
=
5
)
->
tf
.
train
.
SequenceExample
:
"""Generate fake data for unit tests."""
"""Generate fake data for unit tests."""
rgb
=
np
.
random
.
randint
(
low
=
256
,
size
=
1024
,
dtype
=
np
.
uint8
)
rgb
=
np
.
random
.
randint
(
low
=
256
,
size
=
1024
,
dtype
=
np
.
uint8
)
audio
=
np
.
random
.
randint
(
low
=
256
,
size
=
128
,
dtype
=
np
.
uint8
)
audio
=
np
.
random
.
randint
(
low
=
256
,
size
=
128
,
dtype
=
np
.
uint8
)
...
@@ -240,7 +188,7 @@ def MakeYt8mExample(num_segment: int = 5) -> tf.train.SequenceExample:
...
@@ -240,7 +188,7 @@ def MakeYt8mExample(num_segment: int = 5) -> tf.train.SequenceExample:
# TODO(yeqing): Move the test related functions to test_utils.
# TODO(yeqing): Move the test related functions to test_utils.
def
M
ake
E
xample
W
ith
F
loat
F
eatures
(
def
m
ake
_e
xample
_w
ith
_f
loat
_f
eatures
(
num_segment
:
int
=
5
)
->
tf
.
train
.
SequenceExample
:
num_segment
:
int
=
5
)
->
tf
.
train
.
SequenceExample
:
"""Generate fake data for unit tests."""
"""Generate fake data for unit tests."""
rgb
=
np
.
random
.
rand
(
1
,
2048
).
astype
(
np
.
float32
)
rgb
=
np
.
random
.
rand
(
1
,
2048
).
astype
(
np
.
float32
)
...
...
official/projects/yt8m/dataloaders/yt8m_input.py
View file @
ddc9bce0
...
@@ -175,7 +175,7 @@ def _get_video_matrix(features, feature_size, dtype, max_frames,
...
@@ -175,7 +175,7 @@ def _get_video_matrix(features, feature_size, dtype, max_frames,
num_frames
=
tf
.
math
.
minimum
(
tf
.
shape
(
decoded_features
)[
0
],
max_frames
)
num_frames
=
tf
.
math
.
minimum
(
tf
.
shape
(
decoded_features
)[
0
],
max_frames
)
if
dtype
.
is_integer
:
if
dtype
.
is_integer
:
feature_matrix
=
utils
.
D
equantize
(
decoded_features
,
max_quantized_value
,
feature_matrix
=
utils
.
d
equantize
(
decoded_features
,
max_quantized_value
,
min_quantized_value
)
min_quantized_value
)
else
:
else
:
feature_matrix
=
decoded_features
feature_matrix
=
decoded_features
...
...
official/projects/yt8m/dataloaders/yt8m_input_test.py
View file @
ddc9bce0
...
@@ -37,7 +37,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -37,7 +37,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
self
.
data_path
=
os
.
path
.
join
(
data_dir
,
'data.tfrecord'
)
self
.
data_path
=
os
.
path
.
join
(
data_dir
,
'data.tfrecord'
)
self
.
num_segment
=
6
self
.
num_segment
=
6
examples
=
[
utils
.
M
ake
Y
t8m
E
xample
(
self
.
num_segment
)
for
_
in
range
(
8
)]
examples
=
[
utils
.
m
ake
_y
t8m
_e
xample
(
self
.
num_segment
)
for
_
in
range
(
8
)]
tfexample_utils
.
dump_to_tfrecord
(
self
.
data_path
,
tf_examples
=
examples
)
tfexample_utils
.
dump_to_tfrecord
(
self
.
data_path
,
tf_examples
=
examples
)
def
create_input_reader
(
self
,
params
):
def
create_input_reader
(
self
,
params
):
...
@@ -130,7 +130,8 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -130,7 +130,8 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
data_path
=
os
.
path
.
join
(
data_dir
,
'data2.tfrecord'
)
data_path
=
os
.
path
.
join
(
data_dir
,
'data2.tfrecord'
)
examples
=
[
examples
=
[
utils
.
MakeExampleWithFloatFeatures
(
self
.
num_segment
)
for
_
in
range
(
8
)
utils
.
make_example_with_float_features
(
self
.
num_segment
)
for
_
in
range
(
8
)
]
]
tfexample_utils
.
dump_to_tfrecord
(
data_path
,
tf_examples
=
examples
)
tfexample_utils
.
dump_to_tfrecord
(
data_path
,
tf_examples
=
examples
)
...
...
official/projects/yt8m/eval_utils/average_precision_calculator.py
View file @
ddc9bce0
...
@@ -268,6 +268,5 @@ class AveragePrecisionCalculator(object):
...
@@ -268,6 +268,5 @@ class AveragePrecisionCalculator(object):
The normalized prediction.
The normalized prediction.
"""
"""
denominator
=
numpy
.
max
(
predictions
)
-
numpy
.
min
(
predictions
)
denominator
=
numpy
.
max
(
predictions
)
-
numpy
.
min
(
predictions
)
ret
=
(
predictions
-
numpy
.
min
(
predictions
))
/
numpy
.
max
(
ret
=
(
predictions
-
numpy
.
min
(
predictions
))
/
max
(
denominator
,
epsilon
)
denominator
,
epsilon
)
return
ret
return
ret
official/projects/yt8m/eval_utils/eval_util_test.py
0 → 100644
View file @
ddc9bce0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
absl
import
logging
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.projects.yt8m.eval_utils.average_precision_calculator
import
AveragePrecisionCalculator
class
YT8MAveragePrecisionCalculatorTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
self
.
prediction
=
np
.
array
([
[
0.98
,
0.88
,
0.77
,
0.65
,
0.64
,
0.59
,
0.45
,
0.43
,
0.20
,
0.05
],
[
0.878
,
0.832
,
0.759
,
0.621
,
0.458
,
0.285
,
0.134
],
[
0.98
],
[
0.56
],
])
self
.
raw_prediction
=
np
.
random
.
rand
(
5
,
10
)
+
np
.
random
.
randint
(
low
=
0
,
high
=
10
,
size
=
(
5
,
10
))
self
.
ground_truth
=
np
.
array
([[
1
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
0
,
1
],
[
1
,
0
,
1
,
0
,
0
,
1
,
0
],
[
1
],
[
0
]])
self
.
expected_ap
=
np
.
array
([
0.714
,
0.722
,
1.000
,
0.000
,
])
def
test_ap_calculator_ap
(
self
):
# Compare Expected Average Precision with function expected
for
i
,
_
in
enumerate
(
self
.
ground_truth
):
calculator
=
AveragePrecisionCalculator
()
ap
=
calculator
.
ap
(
self
.
prediction
[
i
],
self
.
ground_truth
[
i
])
logging
.
info
(
'DEBUG %dth AP: %r'
,
i
+
1
,
ap
)
def
test_ap_calculator_zero_one_normalize
(
self
):
for
i
,
_
in
enumerate
(
self
.
raw_prediction
):
calculator
=
AveragePrecisionCalculator
()
logging
.
error
(
'%r'
,
self
.
raw_prediction
[
i
])
normalized_score
=
calculator
.
_zero_one_normalize
(
self
.
raw_prediction
[
i
])
self
.
assertAllInRange
(
normalized_score
,
lower_bound
=
0.0
,
upper_bound
=
1.0
)
@
parameterized
.
parameters
((
None
,),
(
3
,),
(
5
,),
(
10
,),
(
20
,))
def
test_ap_calculator_ap_at_n
(
self
,
n
):
for
i
,
_
in
enumerate
(
self
.
ground_truth
):
calculator
=
AveragePrecisionCalculator
(
n
)
ap
=
calculator
.
ap_at_n
(
self
.
prediction
[
i
],
self
.
ground_truth
[
i
],
n
)
logging
.
info
(
'DEBUG %dth AP: %r'
,
i
+
1
,
ap
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/projects/yt8m/modeling/
yt8m_agg_model
s.py
→
official/projects/yt8m/modeling/
nn_layer
s.py
View file @
ddc9bce0
File moved
official/projects/yt8m/modeling/yt8m_model.py
View file @
ddc9bce0
...
@@ -16,9 +16,10 @@
...
@@ -16,9 +16,10 @@
from
typing
import
Optional
from
typing
import
Optional
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.projects.yt8m.configs
import
yt8m
as
yt8m_cfg
from
official.projects.yt8m.configs
import
yt8m
as
yt8m_cfg
from
official.projects.yt8m.modeling
import
yt8m_agg_model
s
from
official.projects.yt8m.modeling
import
nn_layer
s
from
official.projects.yt8m.modeling
import
yt8m_model_utils
as
utils
from
official.projects.yt8m.modeling
import
yt8m_model_utils
as
utils
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
...
@@ -171,7 +172,7 @@ class DbofModel(tf.keras.Model):
...
@@ -171,7 +172,7 @@ class DbofModel(tf.keras.Model):
activation
=
self
.
_act_fn
(
activation
)
activation
=
self
.
_act_fn
(
activation
)
tf
.
summary
.
histogram
(
"hidden1_output"
,
activation
)
tf
.
summary
.
histogram
(
"hidden1_output"
,
activation
)
aggregated_model
=
getattr
(
yt8m_agg_model
s
,
aggregated_model
=
getattr
(
nn_layer
s
,
params
.
yt8m_agg_classifier_model
)
params
.
yt8m_agg_classifier_model
)
norm_args
=
dict
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)
norm_args
=
dict
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)
output
=
aggregated_model
().
create_model
(
output
=
aggregated_model
().
create_model
(
...
...
official/projects/yt8m/train_test.py
View file @
ddc9bce0
...
@@ -36,7 +36,7 @@ class TrainTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -36,7 +36,7 @@ class TrainTest(parameterized.TestCase, tf.test.TestCase):
data_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'data'
)
data_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'data'
)
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
self
.
_data_path
=
os
.
path
.
join
(
data_dir
,
'data.tfrecord'
)
self
.
_data_path
=
os
.
path
.
join
(
data_dir
,
'data.tfrecord'
)
examples
=
[
utils
.
M
ake
Y
t8m
E
xample
()
for
_
in
range
(
8
)]
examples
=
[
utils
.
m
ake
_y
t8m
_e
xample
()
for
_
in
range
(
8
)]
tfexample_utils
.
dump_to_tfrecord
(
self
.
_data_path
,
tf_examples
=
examples
)
tfexample_utils
.
dump_to_tfrecord
(
self
.
_data_path
,
tf_examples
=
examples
)
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment