Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3d64b536
Commit
3d64b536
authored
May 24, 2022
by
Reed Wanderman-Milne
Committed by
A. Unique TensorFlower
May 24, 2022
Browse files
Internal change.
PiperOrigin-RevId: 450785872
parent
00f96640
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
304 additions
and
0 deletions
+304
-0
official/nlp/tasks/masked_lm_determinism_test.py
official/nlp/tasks/masked_lm_determinism_test.py
+103
-0
official/vision/tasks/image_classification_determinism_test.py
...ial/vision/tasks/image_classification_determinism_test.py
+86
-0
official/vision/tasks/maskrcnn_determinism_test.py
official/vision/tasks/maskrcnn_determinism_test.py
+115
-0
No files found.
official/nlp/tasks/masked_lm_determinism_test.py
0 → 100644
View file @
3d64b536
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that masked LM models are deterministic when determinism is enabled."""
import
tensorflow
as
tf
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.tasks
import
masked_lm
class MLMTaskTest(tf.test.TestCase):
  """Tests that the masked LM task trains deterministically."""

  def _build_dataset(self, params, vocab_size):
    """Builds an infinite dataset of constant dummy pretraining features.

    Args:
      params: Data config providing `seq_length` and
        `max_predictions_per_seq`.
      vocab_size: Exclusive upper bound for the random input word ids.

    Returns:
      A repeated `tf.data.Dataset` of BERT pretraining feature dicts with
      batch size 1.
    """

    def dummy_data(_):
      dummy_ids = tf.random.uniform((1, params.seq_length),
                                    maxval=vocab_size,
                                    dtype=tf.int32)
      dummy_mask = tf.ones((1, params.seq_length), dtype=tf.int32)
      dummy_type_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
      dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
      return dict(
          input_word_ids=dummy_ids,
          input_mask=dummy_mask,
          input_type_ids=dummy_type_ids,
          masked_lm_positions=dummy_lm,
          masked_lm_ids=dummy_lm,
          masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
          next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def _build_and_run_model(self, config, num_steps=5):
    """Builds the MLM task, trains `num_steps`, then runs one validation step.

    Args:
      config: A `masked_lm.MaskedLMConfig` describing the task.
      num_steps: Number of training steps to run before validation.

    Returns:
      Tuple of (training logs, validation logs, list of model weights).
    """
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = self._build_dataset(config.train_data,
                                  config.model.encoder.get().vocab_size)
    iterator = iter(dataset)
    # `lr` is a deprecated alias of `learning_rate`; use the canonical name.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

    # Run training.
    for _ in range(num_steps):
      logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    for metric in metrics:
      logs[metric.name] = metric.result()

    # Run validation.
    validation_logs = task.validation_step(next(iterator), model,
                                           metrics=metrics)
    for metric in metrics:
      validation_logs[metric.name] = metric.result()

    return logs, validation_logs, model.weights

  def test_task_determinism(self):
    """Runs the model twice from the same seed and checks identical results."""
    config = masked_lm.MaskedLMConfig(
        init_checkpoint=self.get_temp_dir(),
        scale_loss=True,
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    tf.keras.utils.set_random_seed(1)
    logs1, validation_logs1, weights1 = self._build_and_run_model(config)
    tf.keras.utils.set_random_seed(1)
    logs2, validation_logs2, weights2 = self._build_and_run_model(config)

    # With op determinism enabled, both runs must be bit-for-bit identical.
    self.assertEqual(logs1["loss"], logs2["loss"])
    self.assertEqual(validation_logs1["loss"], validation_logs2["loss"])
    for weight1, weight2 in zip(weights1, weights2):
      self.assertAllEqual(weight1, weight2)
if __name__ == "__main__":
  # Determinism must be enabled before any TF ops execute so every op in the
  # test uses a deterministic implementation.
  tf.config.experimental.enable_op_determinism()
  tf.test.main()
official/vision/tasks/image_classification_determinism_test.py
0 → 100644
View file @
3d64b536
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that image classification models are deterministic."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.vision.tasks
import
image_classification
class ImageClassificationDeterminismTaskTest(tf.test.TestCase,
                                             parameterized.TestCase):
  """Tests that the image classification task trains deterministically."""

  def _build_and_run_model(self, config, num_steps=5):
    """Builds the task, trains `num_steps`, then runs one validation step.

    Args:
      config: Experiment config with `task` and `trainer` sections.
      num_steps: Number of training steps to run before validation.
        Parameterized (previously hard-coded to 5) for consistency with the
        sibling determinism tests.

    Returns:
      Tuple of (training logs, validation logs, list of model weights).
    """
    task = image_classification.ImageClassificationTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()
    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                   config.task.train_data)
    iterator = iter(dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # EMA optimizers need a shadow copy of the weights before training.
    if isinstance(optimizer, optimization.ExponentialMovingAverage
                 ) and not optimizer.has_shadow_copy:
      optimizer.shadow_copy(model)

    # Run training.
    for _ in range(num_steps):
      logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    for metric in metrics:
      logs[metric.name] = metric.result()

    # Run validation.
    validation_logs = task.validation_step(next(iterator), model,
                                           metrics=metrics)
    for metric in metrics:
      validation_logs[metric.name] = metric.result()

    return logs, validation_logs, model.weights

  def test_task_deterministic(self):
    """Runs the model twice from the same seed and checks identical results."""
    config_name = "resnet_imagenet"
    config = exp_factory.get_exp_config(config_name)
    # Keep the batch tiny so the test runs quickly.
    config.task.train_data.global_batch_size = 2
    # TODO(b/202552359): Run the two models in separate processes. Some
    # potential sources of non-determinism only occur when the runs are each
    # done in a different process.
    tf.keras.utils.set_random_seed(1)
    logs1, validation_logs1, weights1 = self._build_and_run_model(config)
    tf.keras.utils.set_random_seed(1)
    logs2, validation_logs2, weights2 = self._build_and_run_model(config)

    # With op determinism enabled, both runs must be bit-for-bit identical.
    self.assertEqual(logs1["loss"], logs2["loss"])
    self.assertEqual(logs1["accuracy"], logs2["accuracy"])
    self.assertEqual(logs1["top_5_accuracy"], logs2["top_5_accuracy"])
    self.assertEqual(validation_logs1["loss"], validation_logs2["loss"])
    self.assertEqual(validation_logs1["accuracy"], validation_logs2["accuracy"])
    self.assertEqual(validation_logs1["top_5_accuracy"],
                     validation_logs2["top_5_accuracy"])
    for weight1, weight2 in zip(weights1, weights2):
      self.assertAllEqual(weight1, weight2)
if __name__ == "__main__":
  # Determinism must be enabled before any TF ops execute so every op in the
  # test uses a deterministic implementation.
  tf.config.experimental.enable_op_determinism()
  tf.test.main()
official/vision/tasks/maskrcnn_determinism_test.py
0 → 100644
View file @
3d64b536
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that Mask RCNN is deterministic when TF determinism is enabled."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.vision.tasks
import
maskrcnn
class MaskRcnnTaskTest(parameterized.TestCase, tf.test.TestCase):
  """Tests that Mask R-CNN family tasks train deterministically."""

  def _edit_config_for_testing(self, config):
    """Shrinks a registered experiment config so it runs fast locally.

    Mutates `config` in place: tiny model dims, small batch and input size,
    few proposals/steps, and local COCO TFRecord shard paths.
    """
    config.trainer.steps_per_loop = 1
    config.task.train_data.global_batch_size = 2
    config.task.model.backbone.resnet.model_id = 18
    config.task.model.decoder.fpn.num_filters = 32
    config.task.model.detection_generator.pre_nms_top_k = 500
    config.task.model.detection_head.fc_dims = 128
    if config.task.model.include_mask:
      config.task.model.mask_sampler.num_sampled_masks = 10
      config.task.model.mask_head.num_convs = 1
    config.task.model.roi_generator.num_proposals = 100
    config.task.model.roi_generator.pre_nms_top_k = 150
    config.task.model.roi_generator.test_pre_nms_top_k = 150
    config.task.model.roi_generator.test_num_proposals = 100
    config.task.model.rpn_head.num_filters = 32
    config.task.model.roi_sampler.num_sampled_rois = 200
    config.task.model.input_size = [128, 128, 3]
    config.trainer.train_steps = 2
    config.task.train_data.shuffle_buffer_size = 2
    config.task.train_data.input_path = "coco/train-00000-of-00256.tfrecord"
    config.task.validation_data.global_batch_size = 2
    config.task.validation_data.input_path = "coco/val-00000-of-00032.tfrecord"

  def _build_and_run_model(self, config):
    """Builds the task, runs one training step and one validation step.

    Args:
      config: Experiment config with `task` and `trainer` sections.

    Returns:
      Tuple of (training logs, validation logs, list of model weights).
    """
    task = maskrcnn.MaskRCNNTask(config.task)
    model = task.build_model()
    train_metrics = task.build_metrics(training=True)
    validation_metrics = task.build_metrics(training=False)
    strategy = tf.distribute.get_strategy()
    train_dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, config.task.train_data)
    train_iterator = iter(train_dataset)
    validation_dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, config.task.validation_data)
    validation_iterator = iter(validation_dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # Run training.
    logs = task.train_step(next(train_iterator), model, optimizer,
                           metrics=train_metrics)
    for metric in train_metrics:
      logs[metric.name] = metric.result()

    # Run validation.
    validation_logs = task.validation_step(next(validation_iterator), model,
                                           metrics=validation_metrics)
    for metric in validation_metrics:
      validation_logs[metric.name] = metric.result()

    return logs, validation_logs, model.weights

  @parameterized.parameters(
      "fasterrcnn_resnetfpn_coco",
      "maskrcnn_resnetfpn_coco",
      "maskrcnn_spinenet_coco",
      "cascadercnn_spinenet_coco",
  )
  def test_maskrcnn_task_train(self, test_config):
    """Mask R-CNN family determinism test for training and val on toy configs."""
    config = exp_factory.get_exp_config(test_config)
    self._edit_config_for_testing(config)
    tf.keras.utils.set_random_seed(1)
    logs1, validation_logs1, weights1 = self._build_and_run_model(config)
    tf.keras.utils.set_random_seed(1)
    logs2, validation_logs2, weights2 = self._build_and_run_model(config)

    # With op determinism enabled, both runs must be bit-for-bit identical.
    # (The original asserted the "loss" equality twice; the duplicate is
    # removed.)
    self.assertAllEqual(logs1["loss"], logs2["loss"])
    self.assertAllEqual(logs1["total_loss"], logs2["total_loss"])
    self.assertAllEqual(validation_logs1["coco_metric"][1]["detection_boxes"],
                        validation_logs2["coco_metric"][1]["detection_boxes"])
    self.assertAllEqual(validation_logs1["coco_metric"][1]["detection_scores"],
                        validation_logs2["coco_metric"][1]["detection_scores"])
    self.assertAllEqual(validation_logs1["coco_metric"][1]["detection_classes"],
                        validation_logs2["coco_metric"][1]["detection_classes"])
    for weight1, weight2 in zip(weights1, weights2):
      self.assertAllEqual(weight1, weight2)
if __name__ == "__main__":
  # Determinism must be enabled before any TF ops execute so every op in the
  # test uses a deterministic implementation.
  tf.config.experimental.enable_op_determinism()
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment