Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
cf66c525
Commit
cf66c525
authored
Apr 15, 2022
by
qianyj
Browse files
update some TF file
parent
6b6f8b0c
Changes
264
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2287 additions
and
0 deletions
+2287
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/input_reader.py
...alidation/ResNet50_Official/official/core/input_reader.py
+478
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/registry.py
...cy_Validation/ResNet50_Official/official/core/registry.py
+101
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/registry_test.py
...lidation/ResNet50_Official/official/core/registry_test.py
+88
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/task_factory.py
...alidation/ResNet50_Official/official/core/task_factory.py
+70
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/test_utils.py
..._Validation/ResNet50_Official/official/core/test_utils.py
+59
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_lib.py
...y_Validation/ResNet50_Official/official/core/train_lib.py
+150
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_lib_test.py
...idation/ResNet50_Official/official/core/train_lib_test.py
+225
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_utils.py
...Validation/ResNet50_Official/official/core/train_utils.py
+478
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_utils_test.py
...ation/ResNet50_Official/official/core/train_utils_test.py
+98
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/__init__.py
...alidation/ResNet50_Official/official/modeling/__init__.py
+14
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/__init__.py
...sNet50_Official/official/modeling/activations/__init__.py
+21
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/gelu.py
...n/ResNet50_Official/official/modeling/activations/gelu.py
+32
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/gelu_test.py
...Net50_Official/official/modeling/activations/gelu_test.py
+34
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/relu.py
...n/ResNet50_Official/official/modeling/activations/relu.py
+31
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/relu_test.py
...Net50_Official/official/modeling/activations/relu_test.py
+35
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/sigmoid.py
...esNet50_Official/official/modeling/activations/sigmoid.py
+31
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/sigmoid_test.py
...50_Official/official/modeling/activations/sigmoid_test.py
+40
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/swish.py
.../ResNet50_Official/official/modeling/activations/swish.py
+72
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/swish_test.py
...et50_Official/official/modeling/activations/swish_test.py
+44
-0
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
.../modeling/fast_training/experimental/tf2_utils_2x_wide.py
+186
-0
No files found.
Too many changes to show.
To preserve performance only
264 of 264+
files are displayed.
Plain diff
Email patch
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/input_reader.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A common dataset reader."""
import
random
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Text
,
Union
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
from
official.core
import
config_definitions
as
cfg
def
_get_random_integer
():
return
random
.
randint
(
0
,
(
1
<<
31
)
-
1
)
def _maybe_map_fn(dataset: tf.data.Dataset,
                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
  """Applies `fn` to every element of `dataset` when `fn` is provided.

  Args:
    dataset: The dataset to transform.
    fn: Optional callable handed to `dataset.map`. When it is None the dataset
      is returned unchanged.

  Returns:
    `dataset` itself if `fn` is None, otherwise the result of `dataset.map`
    called with `num_parallel_calls` set to `tf.data.experimental.AUTOTUNE`.
  """
  if fn is None:
    return dataset
  return dataset.map(fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
  """Matches files from an input_path.

  Every input string is split on commas. Each non-empty, whitespace-stripped
  piece is expanded with `tf.io.gfile.glob` when it contains `*` or `?`, and
  is otherwise taken verbatim (no existence check is performed).

  Args:
    input_path: Either a single string (one path/pattern, or several separated
      by commas) or a list/tuple of such strings.

  Returns:
    The flat list of matched file paths, in input order.

  Raises:
    ValueError: If `input_path` is not a str or a list/tuple of str, if a glob
      pattern matches nothing, or if no file is produced at all.
  """
  usage = ('`input_path` should be either (1) a str indicating a file '
           'path/pattern, or (2) a str indicating multiple file '
           'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
           '"a,b,c", or (3) a list of str, each of which is a file '
           'path/pattern or multiple file paths/patterns separated by '
           'comma, but got: %s')
  # NOTE: `input_path` is always wrapped in a 1-tuple before `%` formatting.
  # Without the wrapper a tuple `input_path` would be unpacked as the format
  # arguments and raise TypeError instead of the intended ValueError.
  if isinstance(input_path, str):
    input_path_list = [input_path]
  elif isinstance(input_path, (list, tuple)):
    if any(not isinstance(x, str) for x in input_path):
      raise ValueError(usage % (input_path,))
    input_path_list = input_path
  else:
    raise ValueError(usage % (input_path,))

  matched_files = []
  # Use a dedicated loop variable so the `input_path` argument stays intact
  # for the final error message.
  for path_spec in input_path_list:
    for input_pattern in path_spec.strip().split(','):
      input_pattern = input_pattern.strip()
      if not input_pattern:
        continue
      if '*' in input_pattern or '?' in input_pattern:
        tmp_matched_files = tf.io.gfile.glob(input_pattern)
        if not tmp_matched_files:
          raise ValueError('%s does not match any files.' % input_pattern)
        matched_files.extend(tmp_matched_files)
      else:
        matched_files.append(input_pattern)

  if not matched_files:
    raise ValueError('%s does not match any files.' % (input_path,))
  return matched_files
def _read_files_then_shard(
    matched_files: List[str],
    dataset_fn,
    input_context: Optional[tf.distribute.InputContext] = None,
    sharding: bool = False,
    repeat: bool = False) -> tf.data.Dataset:
  """Reads all files on every worker, then optionally shards the records.

  Args:
    matched_files: File paths handed to `dataset_fn` in a single call.
    dataset_fn: Callable that builds a dataset from the list of files.
    input_context: Optional distribution context supplying the number of input
      pipelines and the id of the current one.
    sharding: Whether to shard the dataset across input pipelines.
    repeat: Whether to repeat the dataset indefinitely.

  Returns:
    The resulting `tf.data.Dataset`.
  """
  dataset = dataset_fn(matched_files)

  # Auto sharding is turned off so that the same input files are sent to all
  # workers (single file, or fewer files than input pipelines).
  no_auto_shard = tf.data.Options()
  no_auto_shard.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  dataset = dataset.with_options(no_auto_shard)

  # Sharding is skipped when tf.data service is enabled, because the service
  # handles sharding itself.
  if sharding and input_context:
    num_pipelines = input_context.num_input_pipelines
    if num_pipelines > 1:
      dataset = dataset.shard(num_pipelines, input_context.input_pipeline_id)

  return dataset.repeat() if repeat else dataset
def _shard_files_then_read(
    matched_files: List[str],
    dataset_fn,
    input_context: Optional[tf.distribute.InputContext] = None,
    seed: Optional[Union[int, tf.Tensor]] = None,
    is_training: bool = False,
    sharding: bool = False,
    cache: bool = False,
    cycle_length: Optional[int] = None,
    block_length: Optional[int] = None,
    deterministic: bool = False) -> tf.data.Dataset:
  """Shards the file list across workers, then reads each worker's files.

  Args:
    matched_files: File paths turned into a dataset of file names.
    dataset_fn: Callable mapping a file name to a dataset; used as the
      `map_func` of `interleave`.
    input_context: Optional distribution context supplying the number of input
      pipelines and the id of the current one.
    seed: Optional seed for the file-level shuffle.
    is_training: Whether to shuffle (and, without cache, repeat) the files.
    sharding: Whether to shard the file list across input pipelines.
    cache: Whether caching is enabled downstream; this disables reshuffling
      and defers `repeat()` to the caller.
    cycle_length: Forwarded to `interleave`; also used as `num_parallel_calls`
      when truthy.
    block_length: Forwarded to `interleave`.
    deterministic: Forwarded to `interleave`.

  Returns:
    The interleaved `tf.data.Dataset` of records.
  """
  file_count = len(matched_files)
  dataset = tf.data.Dataset.from_tensor_slices(matched_files)

  # Shuffle at file level. With cache enabled the same cached data is read in
  # every iteration anyway, so reshuffling is switched off.
  if is_training:
    # A seed is needed when sharding so that the per-worker shards of the
    # shuffled file list do not overlap.
    if seed is None and sharding:
      seed = _get_random_integer()
    dataset = dataset.shuffle(
        file_count, seed=seed, reshuffle_each_iteration=not cache)

  # Sharding is skipped when tf.data service is enabled, because the service
  # handles sharding itself.
  if sharding and input_context:
    num_pipelines = input_context.num_input_pipelines
    if num_pipelines > 1:
      dataset = dataset.shard(num_pipelines, input_context.input_pipeline_id)

  # With cache enabled, `repeat()` is called later, after `cache()`.
  if is_training and not cache:
    dataset = dataset.repeat()

  parallel_calls = cycle_length or tf.data.experimental.AUTOTUNE
  return dataset.interleave(
      map_func=dataset_fn,
      cycle_length=cycle_length,
      block_length=block_length,
      num_parallel_calls=parallel_calls,
      deterministic=deterministic)
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
  """Builds a dataset from a TFDS builder.

  Args:
    tfds_builder: The TFDS builder to read from.
    tfds_split: Split name passed to `as_dataset`.
    tfds_skip_decoding_feature: Comma-separated feature names whose decoding
      is skipped; may be empty.
    tfds_as_supervised: Forwarded to `as_dataset` as `as_supervised`.
    input_context: Optional distribution context stored in the read config.
    seed: Optional shuffle seed stored in the read config.
    is_training: Whether to shuffle files and (without cache) repeat.
    cache: Whether caching is enabled downstream; suppresses `repeat()` here.
    cycle_length: Interleave cycle length stored in the read config.
    block_length: Interleave block length stored in the read config.

  Returns:
    The resulting `tf.data.Dataset`.
  """
  # No op if the data already exists.
  tfds_builder.download_and_prepare()

  if tfds_skip_decoding_feature:
    decoders = {
        feature.strip(): tfds.decode.SkipDecoding()
        for feature in tfds_skip_decoding_feature.split(',')
    }
  else:
    decoders = {}

  dataset = tfds_builder.as_dataset(
      split=tfds_split,
      shuffle_files=is_training,
      as_supervised=tfds_as_supervised,
      decoders=decoders,
      read_config=tfds.ReadConfig(
          interleave_cycle_length=cycle_length,
          interleave_block_length=block_length,
          input_context=input_context,
          shuffle_seed=seed))

  if not is_training or cache:
    return dataset
  return dataset.repeat()
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  # A static random number which is the same across different InputReader
  # instances.
  static_randnum = _get_random_integer()

  def __init__(self,
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               combine_fn: Optional[Callable[..., Any]] = None,
               sample_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               transform_and_batch_fn: Optional[Callable[
                   [tf.data.Dataset, Optional[tf.distribute.InputContext]],
                   tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: A `tf.data.Dataset` that consumes the input files. For
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionary of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
        input and outputs the transformed dataset. It performs sampling on the
        decoded raw tensors dict before the parser_fn.
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parse them into a dictionary of tensors that can be consumed by the
        model. It will be executed after decoder_fn.
      transform_and_batch_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
        input, and returns a `tf.data.Dataset` object. It will be executed after
        `parser_fn` to transform and batch the dataset; if None, after
        `parser_fn` is executed, the dataset will be batched into per-replica
        batch size.
      postprocess_fn: A optional `callable` that processes batched tensors. It
        will be executed after batching.

    Raises:
      ValueError: If both `input_path` and `tfds_name` are set, if `input_path`
        is a config (dictionary) and no `combine_fn` is given, or if TFDS is
        used without `tfds_split`.
    """
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))

    if isinstance(params.input_path,
                  cfg.base_config.Config) and combine_fn is None:
      raise ValueError(
          'A `combine_fn` is required if the `input_path` is a dictionary.')

    self._tfds_builder = None
    self._matched_files = None
    if not params.input_path:
      # Read dataset from TFDS.
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)
    else:
      self._matched_files = self.get_files(params.input_path)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._deterministic = params.deterministic
    self._sharding = params.sharding
    self._tfds_split = params.tfds_split
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._combine_fn = combine_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
    self._postprocess_fn = postprocess_fn
    self._seed = params.seed

    # When tf.data service is enabled, each data service worker should get
    # different random seeds. Thus, we set `seed` to None.
    # Sharding should also be disabled because tf data service handles how
    # each worker shard data with `processing_mode` in distribute method.
    if params.enable_tf_data_service:
      self._seed = None
      self._sharding = False

    self._enable_tf_data_service = (
        params.enable_tf_data_service and params.tf_data_service_address)
    self._tf_data_service_address = params.tf_data_service_address
    if self._enable_tf_data_service:
      # Add a random seed as the tf.data service job name suffix, so tf.data
      # service doesn't reuse the previous state if TPU worker gets preempted.
      self._tf_data_service_job_name = (
          params.tf_data_service_job_name + str(self.static_randnum))
      self._enable_round_robin_tf_data_service = params.get(
          'enable_round_robin_tf_data_service', False)

  @property
  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available.

    Raises:
      ValueError: If the reader was not configured to load from TFDS.
    """
    if self._tfds_builder:
      return self._tfds_builder.info
    else:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')

  def get_files(self, input_path):
    """Gets matched files. Can be overridden by subclasses.

    Args:
      input_path: A path/pattern string, a list of them, or a config object
        mapping dataset names to such values.

    Returns:
      None when `input_path` is empty; a dict of name -> list of files when
      `input_path` is a config object; otherwise a list of files.
    """
    if not input_path:
      return None
    # we want to combine / mix datasets
    if isinstance(input_path, cfg.base_config.Config):
      matched_files = {}
      for k, v in input_path.as_dict().items():
        matched_files[k] = match_files(v)
    # single dataset
    else:
      matched_files = match_files(input_path)
    return matched_files

  def _read_data_source(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: Optional[tfds.core.DatasetBuilder] = None):
    """Reads the data source (files/tfds) to a dataset.

    Args:
      matched_files: A list of files, or a dict of name -> list of files.
      dataset_fn: Callable that builds a dataset from file(s).
      input_context: Optional distribution context.
      tfds_builder: Optional TFDS builder; when truthy the data is read from
        TFDS and `matched_files` is ignored.

    Returns:
      A dataset, or a dict of datasets keyed like `matched_files`.

    Raises:
      ValueError: If a file list is empty, or `matched_files` is neither a
        list/tuple nor a dict while no TFDS builder is given.
    """

    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      if len(files) > 1:
        if input_context and (len(files) < input_context.num_input_pipelines):
          # `logging.warn` is a deprecated alias in absl; use `warning`.
          logging.warning(
              'The number of files %d is less than the number of input pipelines '
              '%d. We will send all input files to every worker. '
              'Please consider sharding your data into more files.', len(files),
              input_context.num_input_pipelines)
          return _read_files_then_shard(
              files,
              dataset_fn,
              input_context,
              sharding=self._sharding,
              repeat=self._is_training and not self._cache)
        else:
          return _shard_files_then_read(
              files,
              dataset_fn,
              input_context,
              seed=self._seed,
              is_training=self._is_training,
              sharding=self._sharding,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length,
              deterministic=self._deterministic)
      elif len(files) == 1:
        return _read_files_then_shard(
            files,
            dataset_fn,
            input_context,
            sharding=self._sharding,
            repeat=self._is_training and not self._cache)
      else:
        raise ValueError('It is unexpected that `tfds_builder` is None and '
                         'there is also no `files`.')

    if tfds_builder:
      # NOTE(review): the builder actually used is `self._tfds_builder`, not
      # the `tfds_builder` argument, which only selects this branch. `read()`
      # passes the same object for both; confirm before relying on a
      # different argument value.
      dataset = _read_tfds(
          tfds_builder=self._tfds_builder,
          tfds_split=self._tfds_split,
          tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
          tfds_as_supervised=self._tfds_as_supervised,
          input_context=input_context,
          seed=self._seed,
          is_training=self._is_training,
          cache=self._cache,
          cycle_length=self._cycle_length,
          block_length=self._block_length)
    elif isinstance(matched_files, (list, tuple)):
      dataset = _files_to_dataset(matched_files)
    elif isinstance(matched_files, dict):
      dataset = {}
      for k, fs in matched_files.items():
        dataset[k] = _files_to_dataset(fs)
    else:
      raise ValueError('`matched_files` should be a list or dict.')

    return dataset

  def _decode_and_parse_dataset(
      self,
      dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after shuffling, decoding, and parsing.

    Args:
      dataset: A dataset, or a dict of datasets to be merged by `combine_fn`.
      batch_size: Global batch size; converted to a per-replica batch size
        when `input_context` is given and no `transform_and_batch_fn` is set.
      input_context: Optional distribution context.

    Returns:
      The decoded, parsed and batched dataset.
    """

    def _shuffle_and_decode(ds):
      # If cache is enabled, we will call `shuffle()` later after `cache()`.
      if self._is_training and not self._cache:
        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
      # Decode
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds

    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
    if tf.nest.is_nested(dataset):
      dataset = self._combine_fn(dataset)

    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
    dataset = _maybe_map_fn(dataset, self._parser_fn)

    if self._cache:
      dataset = dataset.cache()
      if self._is_training:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)

    if self._transform_and_batch_fn is not None:
      dataset = self._transform_and_batch_fn(dataset, input_context)
    else:
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          batch_size) if input_context else batch_size
      dataset = dataset.batch(
          per_replica_batch_size, drop_remainder=self._drop_remainder)

    return dataset

  def _maybe_apply_data_service(
      self,
      dataset: tf.data.Dataset,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Potentially distributes a dataset.

    The dataset is returned unchanged unless tf.data service is enabled and an
    `input_context` is given.

    Args:
      dataset: The dataset to distribute.
      input_context: Optional distribution context.

    Returns:
      The (possibly distributed) dataset.
    """
    if self._enable_tf_data_service and input_context:
      if self._enable_round_robin_tf_data_service:
        replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
            input_context.num_input_pipelines)
        base_consumer_index = input_context.input_pipeline_id * (
            replicas_per_input_pipeline)
        num_consumers = input_context.num_input_pipelines * (
            replicas_per_input_pipeline)
        range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
        dataset = range_dataset.map(lambda i: dataset.apply(  # pylint: disable=g-long-lambda
            tf.data.experimental.service.distribute(
                processing_mode='parallel_epochs',
                service=self._tf_data_service_address,
                job_name=self._tf_data_service_job_name,
                consumer_index=base_consumer_index + i,
                num_consumers=num_consumers)))
        # Use parallel interleave to read multiple batches from a tf.data
        # service worker in parallel.
        dataset = dataset.interleave(
            lambda x: x,
            cycle_length=replicas_per_input_pipeline,
            num_parallel_calls=replicas_per_input_pipeline,
            deterministic=True)
      else:
        dataset = dataset.apply(
            tf.data.experimental.service.distribute(
                processing_mode='parallel_epochs',
                service=self._tf_data_service_address,
                job_name=self._tf_data_service_job_name))
    return dataset

  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object.

    Args:
      input_context: Optional distribution context.
      dataset: Optional pre-built dataset. When None, the data source
        configured at construction time (files or TFDS) is read.

    Returns:
      The decoded, parsed, batched, post-processed and prefetched dataset.
    """
    if dataset is None:
      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
                                       input_context, self._tfds_builder)
    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                             input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._deterministic is not None:
      options = tf.data.Options()
      options.experimental_deterministic = self._deterministic
      dataset = dataset.with_options(options)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/registry.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry utility."""
def register(registered_collection, reg_key):
  """Builds a decorator that stores its target in `registered_collection`.

  String keys are hierarchical: for reg_key="my_model/my_exp/my_config_0" the
  decorated function or class ends up at
  registered_collection["my_model"]["my_exp"]["my_config_0"], with the
  intermediate dictionaries created on demand. Non-string keys are used as-is
  at the top level. Entries are meant to be read back with lookup() from this
  file.

  Args:
    registered_collection: a dictionary that receives the decorated function or
      class.
    reg_key: The key under which the function or class is stored. A string key
      may be hierarchical, e.g. my_model/my_exp/my_config_0.

  Returns:
    A decorator function that registers its argument and returns it unchanged.

  Raises:
    KeyError: (from the decorator) when the key is already registered, or when
      a segment of the path is already registered as a function or class.
  """

  def decorator(fn_or_cls):
    """Put fn_or_cls in the dictionary."""
    node = registered_collection
    leaf = reg_key
    if isinstance(reg_key, str):
      *parents, leaf = reg_key.split("/")
      for depth, segment in enumerate(parents):
        if segment not in node:
          node[segment] = {}
        node = node[segment]
        if not isinstance(node, dict):
          raise KeyError(
              "Collection path {} at position {} already registered as "
              "a function or class.".format(segment, depth))

    if leaf in node:
      raise KeyError(
          "Function or class {} registered multiple times.".format(leaf))

    node[leaf] = fn_or_cls
    return fn_or_cls

  return decorator
def lookup(registered_collection, reg_key):
  """Lookup and return decorated function or class in the collection.

  Lookup decorated function or class in registered_collection, in a
  hierarchical order. For example, when
  reg_key="my_model/my_exp/my_config_0",
  this function will return
  registered_collection["my_model"]["my_exp"]["my_config_0"].

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      retrieved from this collection.
    reg_key: The key for retrieving the registered function or class. If reg_key
      is a string, it can be hierarchical like my_model/my_exp/my_config_0

  Returns:
    The registered function or class.

  Raises:
    LookupError: when reg_key cannot be found, including when the path tries
      to descend below an entry that is itself a registered function or class.
  """
  if isinstance(reg_key, str):
    hierarchy = reg_key.split("/")
    collection = registered_collection
    for h_idx, entry_name in enumerate(hierarchy):
      # Below the root, register() in this file only ever nests plain dicts,
      # so a non-dict node is a registered leaf. Testing `in` on such a leaf
      # would raise TypeError, so report it as a missing path instead.
      hit_leaf = h_idx > 0 and not isinstance(collection, dict)
      if hit_leaf or entry_name not in collection:
        raise LookupError(
            f"collection path {entry_name} at position {h_idx} is never "
            f"registered. Please make sure the {entry_name} and its library is "
            "imported and linked to the trainer binary.")
      collection = collection[entry_name]
    return collection
  else:
    if reg_key not in registered_collection:
      raise LookupError(
          f"registration key {reg_key} is never "
          f"registered. Please make sure the {reg_key} and its library is "
          "imported and linked to the trainer binary.")
    return registered_collection[reg_key]
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/registry_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for registry."""
import
tensorflow
as
tf
from
official.core
import
registry
class RegistryTest(tf.test.TestCase):
  """Exercises registry.register() and registry.lookup()."""

  def test_register(self):
    """Registered functions and classes are returned by lookup()."""
    registered = {}

    @registry.register(registered, 'functions/func_0')
    def func_test():
      pass

    found_fn = registry.lookup(registered, 'functions/func_0')
    self.assertEqual(found_fn, func_test)

    @registry.register(registered, 'classes/cls_0')
    class ClassRegistryKey:
      pass

    found_cls = registry.lookup(registered, 'classes/cls_0')
    self.assertEqual(found_cls, ClassRegistryKey)

    # A class object can itself serve as the registration key.
    @registry.register(registered, ClassRegistryKey)
    class ClassRegistryValue:
      pass

    found_by_cls = registry.lookup(registered, ClassRegistryKey)
    self.assertEqual(found_by_cls, ClassRegistryValue)

  def test_register_hierarchy(self):
    """Hierarchical, flat and non-string keys produce the expected layout."""
    registered = {}

    @registry.register(registered, 'functions/func_0')
    def func_test0():
      pass

    @registry.register(registered, 'func_1')
    def func_test1():
      pass

    @registry.register(registered, func_test1)
    def func_test2():
      pass

    expected = {
        'functions': {
            'func_0': func_test0,
        },
        'func_1': func_test1,
        func_test1: func_test2,
    }
    self.assertEqual(registered, expected)

  def test_register_error(self):
    """Conflicting registrations and unknown keys raise."""
    registered = {}

    @registry.register(registered, 'functions/func_0')
    def func_test0():  # pylint: disable=unused-variable
      pass

    # Registering below an existing leaf must be rejected.
    with self.assertRaises(KeyError):

      @registry.register(registered, 'functions/func_0/sub_func')
      def func_test1():  # pylint: disable=unused-variable
        pass

    with self.assertRaises(LookupError):
      registry.lookup(registered, 'non-exist')
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/task_factory.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A global factory to register and access all registered tasks."""
from
official.core
import
registry
# Module-level registry filled by `register_task_cls` (keyed by the task config
# class passed to it) and read by `get_task_cls`.
_REGISTERED_TASK_CLS = {}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
  """Returns a decorator that registers a Task factory for a TaskConfig class.

  Typical usage:

  ```
  @dataclasses.dataclass
  class MyTaskConfig(TaskConfig):
    # Add fields here.
    pass

  @register_task_cls(MyTaskConfig)
  class MyTask(Task):
    # Inherits def __init__(self, task_config).
    pass

  my_task_config = MyTaskConfig()
  my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
  ```

  Besides a class itself, any other callable that creates a Task from a
  TaskConfig can be decorated by the result of this function, as long as there
  is at most one registration for each config class.

  Args:
    task_config_cls: a subclass of TaskConfig (*not* an instance of
      TaskConfig). Each task_config_cls can only be used for a single
      registration.

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of task_config_cls.
  """
  decorator = registry.register(_REGISTERED_TASK_CLS, task_config_cls)
  return decorator
def get_task(task_config, **kwargs):
  """Creates a Task (of suitable subclass type) from task_config.

  Args:
    task_config: The task configuration. Its `BUILDER` attribute, when not
      None, is called to create the task; otherwise the class registered for
      the config's class is instantiated.
    **kwargs: Extra keyword arguments forwarded to the builder or task class.

  Returns:
    The object returned by the builder or by the registered task class.
  """
  # TODO(hongkuny): deprecate the task factory to use config.BUILDER.
  builder = task_config.BUILDER
  if builder is None:
    task_cls = get_task_cls(task_config.__class__)
    return task_cls(task_config, **kwargs)
  return builder(task_config, **kwargs)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
  """Returns the task class/factory registered for `task_config_cls`.

  Args:
    task_config_cls: The key used at registration time with
      `register_task_cls`.

  Returns:
    Whatever `registry.lookup` finds for the key in the module-level registry.
  """
  return registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/test_utils.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for testing."""
import
tensorflow
as
tf
class FakeKerasModel(tf.keras.Model):
  """Fake keras model for testing: two stacked 4-unit ReLU dense layers."""

  def __init__(self):
    """Creates the two dense layers."""
    super().__init__()
    self.dense = tf.keras.layers.Dense(units=4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(units=4, activation=tf.nn.relu)

  def call(self, inputs):
    """Feeds `inputs` through `dense` and then `dense2`."""
    hidden = self.dense(inputs)
    return self.dense2(hidden)
class _Dense(tf.Module):
  """A dense layer computing relu(x @ w + b)."""

  def __init__(self, input_dim, output_size, name=None):
    """Creates the weight and bias variables.

    Args:
      input_dim: Number of rows of the weight matrix.
      output_size: Number of columns of the weight matrix and length of the
        bias vector.
      name: Optional module name.
    """
    super().__init__(name=name)
    with self.name_scope:
      initial_w = tf.random.normal([input_dim, output_size])
      self.w = tf.Variable(initial_w, name='w')
      initial_b = tf.zeros([output_size])
      self.b = tf.Variable(initial_b, name='b')

  @tf.Module.with_name_scope
  def __call__(self, x):
    """Returns relu(matmul(x, w) + b)."""
    return tf.nn.relu(tf.matmul(x, self.w) + self.b)
class FakeModule(tf.Module):
  """Test double: two stacked `_Dense` layers inside a `tf.Module`."""

  def __init__(self, input_size, name=None):
    """Builds the sub-modules `dense` and `dense_1` under this name scope.

    Args:
      input_size: Input dimension of the first dense layer.
      name: Optional module name forwarded to `tf.Module`.
    """
    super().__init__(name=name)
    with self.name_scope:
      self.dense = _Dense(input_size, 4, name='dense')
      self.dense2 = _Dense(4, 4, name='dense_1')

  @tf.Module.with_name_scope
  def __call__(self, x):
    """Applies `dense` and then `dense2` to `x`."""
    hidden = self.dense(x)
    return self.dense2(hidden)
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_lib.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM common training driver library."""
# pytype: disable=attribute-error
import
os
from
typing
import
Any
,
Mapping
,
Optional
,
Tuple
# Import libraries
from
absl
import
logging
import
orbit
import
tensorflow
as
tf
from
official.core
import
actions
from
official.core
import
base_task
from
official.core
import
base_trainer
from
official.core
import
config_definitions
from
official.core
import
train_utils
# Module-level alias of the train_utils helper; run_experiment() below calls it
# through this name when it has to build the default trainer.
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Drives training and/or evaluation as described by the experiment params.

  Args:
    distribution_strategy: The distribution strategy whose scope is used for
      trainer creation, the run itself and the optional post evaluation.
    task: A Task instance.
    mode: One of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: Path under which checkpoints and summaries are stored.
    run_post_eval: If True, evaluate once after the run and return the logs.
    save_summary: Whether train and validation summaries are written.
    trainer: Optional pre-built base_trainer.Trainer. It should have been
      created within the strategy scope. When falsy, one is created here.
    controller_cls: Class used to build the controller; must be an
      orbit.Controller subclass.

  Returns:
    A 2-tuple `(model, eval_logs)`: the trainer's `tf.keras.Model` and the
    post-evaluation logs, or `{}` when `run_post_eval` is False.

  Raises:
    ValueError: If the trainer exposes a checkpoint but `model_dir` is None.
    NotImplementedError: If `mode` is not one of the supported values.
  """
  with distribution_strategy.scope():
    if not trainer:
      wants_train = 'train' in mode
      wants_eval = ('eval' in mode) or run_post_eval
      exporter = maybe_create_best_ckpt_exporter(params, model_dir)
      trainer = train_utils.create_trainer(
          params,
          task,
          train=wants_train,
          evaluate=wants_eval,
          checkpoint_exporter=exporter)

  trainer_config = params.trainer

  # A checkpoint manager is only needed when the trainer has a checkpoint.
  checkpoint_manager = None
  if trainer.checkpoint:
    if model_dir is None:
      raise ValueError('model_dir must be specified, but got None')
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=trainer_config.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=trainer_config.checkpoint_interval,
        init_fn=trainer.initialize)

  # All three summary settings are switched off together.
  if save_summary:
    summary_dir = os.path.join(model_dir, 'train')
    eval_summary_dir = os.path.join(
        model_dir, trainer_config.validation_summary_subdir)
    summary_interval = trainer_config.summary_interval
  else:
    summary_dir = None
    eval_summary_dir = None
    summary_interval = None

  controller = controller_cls(
      strategy=distribution_strategy,
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=trainer_config.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=summary_dir,
      eval_summary_dir=eval_summary_dir,
      summary_interval=summary_interval,
      train_actions=actions.get_train_actions(
          params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=trainer_config.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=trainer_config.train_steps,
          eval_steps=trainer_config.validation_steps,
          eval_interval=trainer_config.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=trainer_config.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop waiting for new checkpoints once training reached its last step.
        return bool(
            trainer.global_step.numpy() >= trainer_config.train_steps)

      controller.evaluate_continuously(
          steps=trainer_config.validation_steps,
          timeout=trainer_config.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  num_params = train_utils.try_count_params(trainer.model)
  if num_params is not None:
    logging.info('Number of trainable params in model: %f Millions.',
                 num_params / 10.**6)

  flops = train_utils.try_count_flops(trainer.model)
  if flops is not None:
    logging.info('FLOPs (multi-adds) in model: %f Billions.',
                 flops / 10.**9 / 2)

  if not run_post_eval:
    return trainer.model, {}

  with distribution_strategy.scope():
    return trainer.model, trainer.evaluate(
        tf.convert_to_tensor(trainer_config.validation_steps))
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_lib_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train_ctl_lib."""
import
json
import
os
from
absl
import
flags
from
absl.testing
import
flagsaver
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.common
import
flags
as
tfm_flags
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: enable=unused-import
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.utils.testing
import
mock_task
# Module-level handle on absl's global flag values.
FLAGS = flags.FLAGS

# Called at import time. Presumably this defines the experiment / mode /
# model_dir / params_override flags the tests below set through flagsaver --
# confirm in official.common.flags.
tfm_flags.define_flags()
class TrainTest(tf.test.TestCase, parameterized.TestCase):
  """Tests that drive train_lib.run_experiment with the 'mock' experiment."""

  def setUp(self):
    """Builds the trainer overrides passed to every test as params_override."""
    super(TrainTest, self).setUp()
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'validation_summary_subdir': 'validation',
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    """Runs an experiment end to end and checks the files it leaves behind.

    Verifies the validation summary directory (for modes containing 'eval'),
    that the returned logs are empty exactly when `run_post_eval` is False,
    that `params.yaml` exists, and -- for modes other than 'eval' -- that a
    `checkpoint` file exists. Those modes then also run 'continuous_eval'.
    """
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      _, logs = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir,
          run_post_eval=run_post_eval)

    if 'eval' in flag_mode:
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(model_dir,
                           params.trainer.validation_summary_subdir)))
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
    # Nothing below applies to the eval-only mode: the checkpoint assertion
    # and the continuous-eval run are skipped.
    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'train_and_eval'],
      ))
  def test_recovery_nan_error(self, distribution_strategy, flag_mode):
    """Expects run_experiment to raise RuntimeError when the loss is NaN."""
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        # task = task_factory.get_task(params.task, logging_dir=model_dir)
        task = mock_task.MockTask(params.task, logging_dir=model_dir)

        # Set the loss to NaN to trigger RunTimeError.
        def build_losses(labels, model_outputs, aux_losses=None):
          del labels, model_outputs
          return tf.constant([np.nan], tf.float32) + aux_losses

        # Patch the instance so the trainer uses the NaN loss.
        task.build_losses = build_losses

      with self.assertRaises(RuntimeError):
        train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode=flag_mode,
            params=params,
            model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train'],
      ))
  def test_recovery(self, distribution_strategy, flag_mode):
    """Checks model weights are unchanged by a run whose loss hits the bound.

    The loss is pinned to `loss_threshold`, which is also set as
    `trainer.loss_upper_bound`, with `recovery_max_trials` set to 1. The
    weights returned by run_experiment must equal those captured before it.
    NOTE(review): presumably the trainer's recovery logic restores the weights
    -- confirm in base_trainer.
    """
    loss_threshold = 1.0
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      params.trainer.loss_upper_bound = loss_threshold
      params.trainer.recovery_max_trials = 1
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      # Saves a checkpoint for reference.
      model = task.build_model()
      checkpoint = tf.train.Checkpoint(model=model)
      checkpoint_manager = tf.train.CheckpointManager(
          checkpoint, self.get_temp_dir(), max_to_keep=2)
      checkpoint_manager.save()
      before_weights = model.get_weights()

      # Constant loss equal to the configured upper bound.
      def build_losses(labels, model_outputs, aux_losses=None):
        del labels, model_outputs
        return tf.constant([loss_threshold], tf.float32) + aux_losses

      task.build_losses = build_losses

      model, _ = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir)
      after_weights = model.get_weights()
      for left, right in zip(before_weights, after_weights):
        self.assertAllEqual(left, right)

  def test_parse_configuration(self):
    """Checks `lock_return`: locked params reject override, unlocked accept."""
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode='train',
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
      with self.assertRaises(ValueError):
        params.override({'task': {'init_checkpoint': 'Foo'}})

      params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
      params.override({'task': {'init_checkpoint': 'Bar'}})
      self.assertEqual(params.task.init_checkpoint, 'Bar')
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_utils.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training utils."""
import
copy
import
json
import
os
import
pprint
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
absl
import
logging
import
dataclasses
import
gin
import
orbit
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.framework.convert_to_constants
import
convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from
official.core
import
base_task
from
official.core
import
base_trainer
from
official.core
import
config_definitions
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Any:
  """Get leaf from a dictionary with arbitrary depth with a list of keys.

  Args:
    d: The dictionary to extract value from.
    keys: The sequence of keys to follow, outermost first. Lists and tuples
      are both accepted.

  Returns:
    The value of the leaf.

  Raises:
    KeyError: If the path given by `keys` does not exist in `d`, or if the
      value it leads to is itself a dictionary (i.e. not a leaf).
  """
  leaf = d
  for k in keys:
    if not isinstance(leaf, dict) or k not in leaf:
      # Wrap `keys` in a 1-tuple: with a bare `% keys`, a tuple of keys would
      # be unpacked as several format arguments and raise TypeError instead of
      # the documented KeyError.
      raise KeyError(
          'Path not exist while traversing the dictionary: d with keys'
          ': %s.' % (keys,))
    leaf = leaf[k]

  if isinstance(leaf, dict):
    raise KeyError('The value extracted with keys: %s is not a leaf of the '
                   'dictionary: %s.' % (keys, d))
  return leaf
def cast_leaf_nested_dict(d: Dict[str, Any],
                          cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
  """Applies `cast_fn` to every leaf of a nested dictionary, in place.

  Args:
    d: The dictionary whose leaves are converted. It is modified in place,
      including any nested dictionaries.
    cast_fn: The function applied to each non-dict value.

  Returns:
    The same dictionary object `d`, with its structure unchanged.
  """
  # Only existing keys are reassigned, so the dict never changes size while it
  # is being iterated.
  for key in d:
    current = d[key]
    if isinstance(current, dict):
      d[key] = cast_leaf_nested_dict(current, cast_fn)
    else:
      d[key] = cast_fn(current)
  return d
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
                                    data_dir: str) -> Any:
  """Builds a BestCheckpointExporter when the config asks for one.

  Args:
    params: Experiment config; the `best_checkpoint_export_subdir`,
      `best_checkpoint_eval_metric` and `best_checkpoint_metric_comp` fields
      of `params.trainer` are read.
    data_dir: Directory under which the export sub-directory is placed.

  Returns:
    A BestCheckpointExporter, or None when `data_dir`, the export sub-directory
    or the metric name is falsy.
  """
  trainer_config = params.trainer
  export_subdir = trainer_config.best_checkpoint_export_subdir
  metric_name = trainer_config.best_checkpoint_eval_metric
  metric_comp = trainer_config.best_checkpoint_metric_comp

  if not (data_dir and export_subdir and metric_name):
    return None

  exporter = BestCheckpointExporter(
      os.path.join(data_dir, export_subdir), metric_name, metric_comp)
  logging.info(
      'Created the best checkpoint exporter. '
      'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
      export_subdir, metric_name)
  return exporter
# TODO(b/180147589): Add tests for this module.
class BestCheckpointExporter:
  """Remembers the best evaluation result and saves the matching checkpoint.

  The best logs are written to `info.json` inside the export directory and are
  loaded again on construction when that file already exists.

  Orbit will support an API for checkpoint exporter. This class will be used
  together with orbit once this functionality is ready.
  """

  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
    """Initialization.

    Args:
      export_dir: The directory that will contain exported checkpoints.
      metric_name: The metric used to decide which result is better. It is
        split on `|`, and the parts are used as successive keys into nested
        eval logs.
      metric_comp: How results are compared. Either `lower` or `higher`.

    Raises:
      ValueError: If `metric_comp` is neither `lower` nor `higher`.
    """
    self._export_dir = export_dir
    self._metric_name = metric_name.split('|')
    self._metric_comp = metric_comp
    if metric_comp not in ('lower', 'higher'):
      raise ValueError('best checkpoint metric comp must be one of '
                       'higher, lower. Got: {}'.format(metric_comp))
    logs_dir = os.path.dirname(self.best_ckpt_logs_path)
    tf.io.gfile.makedirs(logs_dir)
    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
    self._checkpoint_manager = None

  def _get_checkpoint_manager(self, checkpoint):
    """Returns the cached manager, rebuilding it if `checkpoint` changed."""
    needs_new_manager = (
        self._checkpoint_manager is None or
        self._checkpoint_manager.checkpoint != checkpoint)
    if needs_new_manager:
      logging.info('Creates a new checkpoint manager.')
      self._checkpoint_manager = tf.train.CheckpointManager(
          checkpoint,
          directory=self._export_dir,
          max_to_keep=1,
          checkpoint_name='best_ckpt')
    return self._checkpoint_manager

  def maybe_export_checkpoint(self,
                              checkpoint,
                              eval_logs,
                              global_step,
                              write_logs=True) -> bool:
    """Saves `checkpoint` if `eval_logs` beat the best logs seen so far.

    Args:
      checkpoint: The checkpoint object to save through a CheckpointManager.
      eval_logs: The new evaluation logs.
      global_step: The step recorded alongside the logs.
      write_logs: Whether to also write the new best logs to `info.json`.

    Returns:
      True if a checkpoint was exported, False otherwise.
    """
    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
                 eval_logs, global_step)
    has_no_best_yet = self._best_ckpt_logs is None
    if not (has_no_best_yet or
            self._new_metric_is_better(self._best_ckpt_logs, eval_logs)):
      return False

    self._best_ckpt_logs = eval_logs
    if write_logs:
      self.export_best_eval_metric(self._best_ckpt_logs, global_step)
    self._get_checkpoint_manager(checkpoint).save()
    return True

  def _maybe_load_best_eval_metric(self):
    """Loads the stored best logs from `info.json`, or None if it is absent."""
    logs_path = self.best_ckpt_logs_path
    if not tf.io.gfile.exists(logs_path):
      return None
    with tf.io.gfile.GFile(logs_path, 'r') as reader:
      contents = reader.read()
      return json.loads(contents)

  def _new_metric_is_better(self, old_logs, new_logs):
    """Check if the metric in new_logs is better than the metric in old_logs."""
    old_value = float(
        orbit.utils.get_value(
            get_leaf_nested_dict(old_logs, self._metric_name)))
    new_value = float(
        orbit.utils.get_value(
            get_leaf_nested_dict(new_logs, self._metric_name)))

    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
                 old_value, new_value)
    # Anything other than 'higher' is treated as 'lower'; __init__ only admits
    # those two values.
    if self._metric_comp == 'higher':
      improved = new_value > old_value
      message = ('[BestCheckpointExporter] '
                 'the new number is better since it is higher.')
    else:
      improved = new_value < old_value
      message = ('[BestCheckpointExporter] '
                 'the new number is better since it is lower.')
    if improved:
      logging.info(message)
      return True
    return False

  def export_best_eval_metric(self, eval_logs, global_step):
    """Export evaluation results of the best checkpoint into a json file."""

    def _as_float(value):
      return float(orbit.utils.get_value(value))

    # NOTE(review): copy.copy is shallow and cast_leaf_nested_dict casts in
    # place, so dictionaries nested inside `eval_logs` are modified too.
    logs_to_write = copy.copy(eval_logs)
    logs_to_write['best_ckpt_global_step'] = global_step
    logs_to_write = cast_leaf_nested_dict(logs_to_write, _as_float)
    # Saving json file is very fast.
    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
      writer.write(json.dumps(logs_to_write, indent=4) + '\n')

  @property
  def best_ckpt_logs(self):
    """The best evaluation logs recorded so far, or None."""
    return self._best_ckpt_logs

  @property
  def best_ckpt_logs_path(self):
    """Path of the `info.json` file inside the export directory."""
    return os.path.join(self._export_dir, 'info.json')

  @property
  def best_ckpt_path(self):
    """Returns the best ckpt path or None if there is no ckpt yet."""
    return tf.train.latest_checkpoint(self._export_dir)
@gin.configurable
def create_trainer(params: config_definitions.ExperimentConfig,
                   task: base_task.Task,
                   train: bool,
                   evaluate: bool,
                   checkpoint_exporter: Optional[BestCheckpointExporter] = None,
                   trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
  """Builds a trainer from the task's model and optimizer.

  Args:
    params: Experiment config; `trainer.optimizer_config` and `runtime` are
      passed to `task.create_optimizer`.
    task: The task that builds the model and the optimizer.
    train: Forwarded to `trainer_cls` as `train`.
    evaluate: Forwarded to `trainer_cls` as `evaluate`.
    checkpoint_exporter: Optional exporter forwarded to `trainer_cls`.
    trainer_cls: The class (or factory) called to create the trainer.

  Returns:
    The object returned by `trainer_cls`.
  """
  logging.info('Running default trainer.')
  # Keyword arguments are evaluated left to right, so the model is still built
  # before the optimizer is created.
  return trainer_cls(
      params,
      task,
      model=task.build_model(),
      optimizer=task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime),
      train=train,
      evaluate=evaluate,
      checkpoint_exporter=checkpoint_exporter)
@dataclasses.dataclass
class ParseConfigOptions:
  """Use this dataclass instead of FLAGS to customize parse_configuration().

  The fields are the attributes parse_configuration() reads from its
  `flags_obj` argument.
  """
  # Name of the registered experiment to load.
  experiment: str
  # Config files applied in order as overrides.
  config_file: List[str]
  # Value written to `runtime.tpu`.
  tpu: str = ''
  # tf.data service address; ignored when empty.
  tf_data_service: str = ''
  # Final override applied on top of the config files.
  params_override: str = ''

  def __contains__(self, name):
    """Returns True if `name` is one of this dataclass's field names.

    Lets `'some_flag' in options` work as it does on a flags object. Only the
    field names are inspected, so no field value is copied.
    """
    return any(field.name == name for field in dataclasses.fields(self))
def parse_configuration(flags_obj, lock_return=True, print_return=True):
  """Builds the ExperimentConfig described by `flags_obj`.

  Overrides are applied in this order: the registered experiment defaults,
  each `config_file`, the TPU / tf.data service addresses, and finally
  `params_override`.

  Args:
    flags_obj: Object exposing `experiment`, `config_file`, `tpu`,
      `params_override` and optionally `tf_data_service`; it must support the
      `in` operator.
    lock_return: Whether to lock the params before returning them.
    print_return: Whether to log the final params.

  Returns:
    The validated experiment params.

  Raises:
    ValueError: If `flags_obj.experiment` is None.
  """
  if flags_obj.experiment is None:
    raise ValueError('The flag --experiment must be specified.')

  # 1. Get the default config from the registered experiment.
  params = exp_factory.get_exp_config(flags_obj.experiment)

  # 2. Get the first level of override from `--config_file`.
  #    `--config_file` is typically used as a template that specifies the
  #    common override for a particular experiment.
  config_files = flags_obj.config_file or []
  for config_file in config_files:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)

  # 3. Override the TPU address and tf.data service address.
  params.override({
      'runtime': {
          'tpu': flags_obj.tpu,
      },
  })
  # The membership test comes first so the attribute is only read when the
  # flag exists.
  has_data_service = ('tf_data_service' in flags_obj and
                      flags_obj.tf_data_service)
  if has_data_service and isinstance(params.task,
                                     config_definitions.TaskConfig):
    params.override({
        'task': {
            'train_data': {
                'tf_data_service_address': flags_obj.tf_data_service,
            },
            'validation_data': {
                'tf_data_service_address': flags_obj.tf_data_service,
            }
        }
    })

  # 4. Get the second level of override from `--params_override`.
  #    `--params_override` is typically used as a further override over the
  #    template. For example, one may define a particular template for training
  #    ResNet50 on ImageNet in a config file and pass it via `--config_file`,
  #    then define different learning rates and pass it via `--params_override`.
  if flags_obj.params_override:
    params = hyperparams.override_params_dict(
        params, flags_obj.params_override, is_strict=True)

  params.validate()
  if lock_return:
    params.lock()

  if print_return:
    formatted = pprint.PrettyPrinter().pformat(params.as_dict())
    logging.info('Final experiment parameters:\n%s', formatted)

  return params
def serialize_config(params: config_definitions.ExperimentConfig,
                     model_dir: str):
  """Writes the experiment config to `<model_dir>/params.yaml`.

  Args:
    params: The experiment config to save.
    model_dir: Target directory; created if needed.

  Raises:
    ValueError: If `model_dir` is None.
  """
  if model_dir is None:
    raise ValueError('model_dir must be specified, but got None')

  yaml_path = os.path.join(model_dir, 'params.yaml')
  logging.info('Saving experiment configuration to %s', yaml_path)
  tf.io.gfile.makedirs(model_dir)
  hyperparams.save_params_dict_to_yaml(params, yaml_path)
def save_gin_config(filename_suffix: str, model_dir: str):
  """Writes gin's operative config string to a file in `model_dir`.

  Args:
    filename_suffix: Inserted into the file name, which has the form
      `operative_config.<filename_suffix>.gin`.
    model_dir: Target directory; created if needed.
  """
  file_name = 'operative_config.{}.gin'.format(filename_suffix)
  gin_path = os.path.join(model_dir, file_name)
  logging.info('Saving gin configurations to %s', gin_path)
  tf.io.gfile.makedirs(model_dir)
  with tf.io.gfile.GFile(gin_path, 'w') as gin_file:
    gin_file.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path):
  """Restores and returns the `global_step` value stored in a checkpoint.

  Args:
    ckpt_file_path: Path of the checkpoint to restore from.

  Returns:
    The restored global step, as returned by the variable's `numpy()`.

  Raises:
    ValueError: If the step is still at its -1 sentinel after restoring, or if
      restoring raises `tf.errors.InvalidArgumentError`.
  """
  # -1 is a sentinel: it survives only if the checkpoint supplied no value.
  step_var = tf.Variable(-1, dtype=tf.int64)
  reader = tf.train.Checkpoint(global_step=step_var)
  try:
    reader.restore(ckpt_file_path).expect_partial()
    restored_step = step_var.numpy()
  except tf.errors.InvalidArgumentError:
    restored_step = -1

  if restored_step == -1:
    raise ValueError('global_step not found in checkpoint {}. '
                     'If you want to run finetune eval jobs, you need to '
                     'make sure that your pretrain model writes '
                     'global_step in its checkpoints.'.format(ckpt_file_path))

  logging.info('get global_step %d from checkpoint %s', restored_step,
               ckpt_file_path)
  return restored_step
def write_json_summary(log_dir, global_step, eval_metrics):
  """Writes evaluation metrics to `<log_dir>/metrics-<global_step>.json`.

  Every metric is stored as a string: `str(value.numpy())` when the value has
  a `numpy` attribute, `str(value)` otherwise.

  Args:
    log_dir: Directory receiving the json file.
    global_step: Step used in the file name and in the log message.
    eval_metrics: Mapping from metric name to metric value.
  """
  serializable_dict = {
      name: str(value.numpy() if hasattr(value, 'numpy') else value)
      for name, value in eval_metrics.items()
  }
  output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
  logging.info('Evaluation results at pretrain step %d: %s', global_step,
               serializable_dict)
  with tf.io.gfile.GFile(output_json, 'w') as writer:
    writer.write(json.dumps(serializable_dict, indent=4) + '\n')
def write_summary(summary_writer, global_step, eval_metrics):
  """Writes each evaluation metric as a scalar TF summary.

  Args:
    summary_writer: Summary writer used as the default writer and then flushed.
    global_step: Step attached to every scalar.
    eval_metrics: Mapping from metric name to metric value.
  """
  # Convert everything up front, so a value that cannot be converted fails
  # before any summary is written.
  scalars = {
      name: float(orbit.utils.get_value(value))
      for name, value in eval_metrics.items()
  }
  with summary_writer.as_default():
    for name, scalar in scalars.items():
      tf.summary.scalar(name, scalar, step=global_step)
    summary_writer.flush()
def remove_ckpts(model_dir):
  """Deletes checkpoint files from `model_dir` so a run can start over.

  Removes every path matching `ckpt-*` and, if present, the `checkpoint` file.

  Args:
    model_dir: Directory holding the checkpoints.
  """
  ckpt_pattern = os.path.join(model_dir, 'ckpt-*')
  logging.info('removing checkpoint files %s', ckpt_pattern)
  for matched_path in tf.io.gfile.glob(ckpt_pattern):
    tf.io.gfile.rmtree(matched_path)

  index_file = os.path.join(model_dir, 'checkpoint')
  if tf.io.gfile.exists(index_file):
    tf.io.gfile.remove(index_file)
def write_model_params(model: Union[tf.Module, tf.keras.Model],
                       output_path: str) -> None:
  """Writes each model variable's name and shape, then the total size.

  One `<name> <shape>` line is written per entry of `model.variables`,
  followed by a blank line and a `Total params: <count>` line.

  Args:
    model: A model instance.
    output_path: Output file path.
  """
  with tf.io.gfile.GFile(output_path, 'w') as out:
    total_params = 0
    for variable in model.variables:
      var_shape = tf.shape(variable)
      total_params += tf.math.reduce_prod(var_shape).numpy()
      out.write(f'{variable.name} {var_shape.numpy().tolist()}\n')
    out.write(f'\nTotal params: {total_params}\n')
def try_count_params(model: Union[tf.Module, tf.keras.Model],
                     trainable_only: bool = False):
  """Counts the model's parameters when that is possible.

  Args:
    model: The model whose parameters are counted.
    trainable_only: Whether to count trainable variables only. Ignored when
      the model has a `count_params` attribute, which is then used instead.

  Returns:
    The number of parameters, or None if `model.count_params()` raises
    ValueError.
  """
  if not hasattr(model, 'count_params'):
    # No Keras-style counter available: add up the sizes of the variables.
    variables = (
        model.trainable_variables if trainable_only else model.variables)
    return sum(
        tf.math.reduce_prod(tf.shape(variable)).numpy()
        for variable in variables)

  try:
    return model.count_params()
  except ValueError:
    logging.info('Number of trainable params unknown, because the build() '
                 'methods in keras layers were not called. This is probably '
                 'because the model was not feed any input, e.g., the max '
                 'train step already reached before this run.')
    return None
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None,
                    output_path: Optional[str] = None):
  """Counts and returns model FLOPs.

  Args:
    model: A model instance.
    inputs_kwargs: Optional keyword arguments describing the inputs. They are
      only used to trace `model.call` when `model.inputs` is empty.
    output_path: Optional file path the profiling results are written to.

  Returns:
    The profiler's `total_float_ops`, or None when the model has no `inputs`
    attribute or when anything fails while counting.
  """
  if not hasattr(model, 'inputs'):
    return None

  try:
    if model.inputs:
      # Get input shape and set batch size to 1.
      input_specs = [
          tf.TensorSpec([1] + model_input.shape[1:], model_input.dtype)
          for model_input in model.inputs
      ]
      concrete_func = tf.function(model).get_concrete_function(input_specs)
    else:
      # If model.inputs is invalid, try to use the input to get concrete
      # function for model.call (subclass model).
      concrete_func = tf.function(model.call).get_concrete_function(
          **inputs_kwargs)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

    # Calculate FLOPs.
    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    opts['output'] = ('none' if output_path is None else
                      f'file:outfile={output_path}')
    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph,
        run_meta=tf.compat.v1.RunMetadata(),
        options=opts)
    return flops.total_float_ops
  except Exception as e:  # pylint: disable=broad-except
    logging.info(
        'Failed to count model FLOPs with error %s, because the build() '
        'methods in keras layers were not called. This is probably because '
        'the model was not feed any input, e.g., the max train step already '
        'reached before this run.', e)
    return None
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/core/train_utils_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.train_utils."""
import
os
import
numpy
as
np
import
tensorflow
as
tf
from
official.core
import
test_utils
from
official.core
import
train_utils
class TrainUtilsTest(tf.test.TestCase):
  """Unit tests for the helpers in official.core.train_utils."""

  def test_get_leaf_nested_dict(self):
    """A full key path returns the leaf value."""
    d = {'a': {'i': {'x': 5}}}
    self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)

  def test_get_leaf_nested_dict_not_leaf(self):
    """A path that stops on a nested dict is rejected."""
    with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
      d = {'a': {'i': {'x': 5}}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i'])

  def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
    """A path whose last key is missing is rejected."""
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      d = {'a': {'i': {'x': 5}}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])

  def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
    """Same check as the missing-key test, with a different absent key."""
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      d = {'a': {'i': {'x': 5}}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])

  def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
    """A path that continues past a leaf is rejected."""
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      d = {'a': {'i': 5}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])

  def test_cast_leaf_nested_dict(self):
    """Leaves at every depth are converted with the cast function."""
    d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
    d = train_utils.cast_leaf_nested_dict(d, int)
    self.assertEqual(d['a']['i']['x'], 123)
    self.assertEqual(d['b'], 456)

  def test_write_model_params_keras_model(self):
    """Checks the written report for a Keras model."""
    inputs = np.zeros([2, 3])
    model = test_utils.FakeKerasModel()
    model(inputs)  # Must do forward pass to build the model.

    filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
    train_utils.write_model_params(model, filepath)
    # Read through a context manager so the file handle is always closed.
    with tf.io.gfile.GFile(filepath, 'r') as f:
      actual = f.read().splitlines()

    expected = [
        'fake_keras_model/dense/kernel:0 [3, 4]',
        'fake_keras_model/dense/bias:0 [4]',
        'fake_keras_model/dense_1/kernel:0 [4, 4]',
        'fake_keras_model/dense_1/bias:0 [4]',
        '',
        'Total params: 36',
    ]
    self.assertEqual(actual, expected)

  def test_write_model_params_module(self):
    """Checks the written report for a plain tf.Module."""
    inputs = np.zeros([2, 3], dtype=np.float32)
    model = test_utils.FakeModule(3, name='fake_module')
    model(inputs)  # Must do forward pass to build the model.

    filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
    train_utils.write_model_params(model, filepath)
    # Read through a context manager so the file handle is always closed.
    with tf.io.gfile.GFile(filepath, 'r') as f:
      actual = f.read().splitlines()

    expected = [
        'fake_module/dense/b:0 [4]',
        'fake_module/dense/w:0 [3, 4]',
        'fake_module/dense_1/b:0 [4]',
        'fake_module/dense_1/w:0 [4, 4]',
        '',
        'Total params: 36',
    ]
    self.assertEqual(actual, expected)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/__init__.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/__init__.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Activations package definition."""
from
official.modeling.activations.gelu
import
gelu
from
official.modeling.activations.relu
import
relu6
from
official.modeling.activations.sigmoid
import
hard_sigmoid
from
official.modeling.activations.swish
import
hard_swish
from
official.modeling.activations.swish
import
identity
from
official.modeling.activations.swish
import
simple_swish
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/gelu.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gaussian error linear unit."""
import
tensorflow
as
tf
@tf.keras.utils.register_keras_serializable(package='Text')
def gelu(x):
  """Applies the Gaussian Error Linear Unit activation.

  Thin wrapper that delegates to `tf.keras.activations.gelu` with
  `approximate=True`.

  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  activated = tf.keras.activations.gelu(x, approximate=True)
  return activated
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/gelu_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Gaussian error linear unit."""
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@keras_parameterized.run_all_keras_modes
class GeluTest(keras_parameterized.TestCase):
  """Checks `activations.gelu` against precomputed reference values."""

  def test_gelu(self):
    inputs = [[.25, 0, -.25], [-1, -2, 3]]
    # Reference outputs for `inputs`, one row per input row.
    reference = [
        [0.14967535, 0., -0.10032465],
        [-0.15880796, -0.04540223, 2.9963627],
    ]
    outputs = activations.gelu(inputs)
    self.assertAllClose(reference, outputs)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/relu.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Relu activation."""
import
tensorflow
as
tf
@tf.keras.utils.register_keras_serializable(package='Text')
def relu6(features):
  """Applies the Relu6 activation.

  The input is first converted to a tensor and then passed to `tf.nn.relu6`.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  return tf.nn.relu6(tf.convert_to_tensor(features))
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/relu_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Relu activation."""
import
tensorflow
as
tf
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@keras_parameterized.run_all_keras_modes
class CustomizedReluTest(keras_parameterized.TestCase):
  """Checks that `activations.relu6` agrees with `tf.nn.relu6`."""

  def test_relu6(self):
    inputs = [[.25, 0, -.25], [-1, -2, 3]]
    actual = activations.relu6(inputs)
    reference = tf.nn.relu6(inputs)
    self.assertAllClose(actual, reference)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/sigmoid.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Sigmoid activation."""
import
tensorflow
as
tf
@tf.keras.utils.register_keras_serializable(package='Text')
def hard_sigmoid(features):
  """Applies the hard sigmoid activation.

  Computed as `relu6(features + 3) * 0.16667`, with the constant 3 cast to the
  dtype of the converted input.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  preact = tf.convert_to_tensor(features)
  shift = tf.cast(3., preact.dtype)
  clipped = tf.nn.relu6(preact + shift)
  return clipped * 0.16667
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/sigmoid_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Sigmoid activation."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@keras_parameterized.run_all_keras_modes
class CustomizedSigmoidTest(keras_parameterized.TestCase):
  """Checks `activations.hard_sigmoid` against a local reference formula."""

  def _hard_sigmoid_nn(self, x):
    """Reference hard sigmoid: `relu6(x + 3) * 0.16667` on float32 input."""
    as_float32 = np.float32(x)
    return tf.nn.relu6(as_float32 + 3.) * 0.16667

  def test_hard_sigmoid(self):
    inputs = [[.25, 0, -.25], [-1, -2, 3]]
    actual = activations.hard_sigmoid(inputs)
    reference = self._hard_sigmoid_nn(inputs)
    self.assertAllClose(actual, reference)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/swish.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Swish activation."""
import
tensorflow
as
tf
@tf.keras.utils.register_keras_serializable(package='Text')
def simple_swish(features):
  """Applies the Swish activation, `features * sigmoid(features)`.

  The tf.nn.swish operation uses a custom gradient to reduce memory usage.
  Since saving custom gradients in SavedModel is currently not supported, and
  one would not be able to use an exported TF-Hub module for fine-tuning, this
  wrapper offers an alternative to the native TensorFlow swish operation: a
  plain composition of ops that uses the default TensorFlow gradient
  computation.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  preact = tf.convert_to_tensor(features)
  gate = tf.nn.sigmoid(preact)
  return preact * gate
@tf.keras.utils.register_keras_serializable(package='Text')
def hard_swish(features):
  """Applies a hard version of the swish function.

  Computed as `features * relu6(features + 3) * (1 / 6)`, with the constant 3
  cast to the dtype of the converted input.

  This operation can be used to reduce computational cost and improve
  quantization for edge devices.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  preact = tf.convert_to_tensor(features)
  shift = tf.cast(3., preact.dtype)
  gate = tf.nn.relu6(preact + shift)
  return preact * gate * (1. / 6.)
@tf.keras.utils.register_keras_serializable(package='Text')
def identity(features):
  """Applies the identity function.

  The input is converted to a tensor and passed through `tf.identity`.
  Useful for helping in quantization.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  return tf.identity(tf.convert_to_tensor(features))
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/activations/swish_test.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Swish activation."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase):
  """Checks the customized swish activations against reference formulas."""

  def _hard_swish_np(self, x):
    """NumPy reference for hard swish: `x * clip(x + 3, 0, 6) / 6`."""
    as_float32 = np.float32(x)
    return as_float32 * np.clip(as_float32 + 3, 0, 6) / 6

  def test_simple_swish(self):
    inputs = [[.25, 0, -.25], [-1, -2, 3]]
    actual = activations.simple_swish(inputs)
    reference = tf.nn.swish(inputs)
    self.assertAllClose(actual, reference)

  def test_hard_swish(self):
    inputs = [[.25, 0, -.25], [-1, -2, 3]]
    actual = activations.hard_swish(inputs)
    reference = self._hard_swish_np(inputs)
    self.assertAllClose(actual, reference)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/Accuracy_Validation/ResNet50_Official/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
0 → 100644
View file @
cf66c525
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stacking model horizontally."""
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
def expand_vector(v: np.ndarray) -> np.ndarray:
  """Duplicates every entry of `v` along its last axis.

  Each element is repeated twice, back to back, so `[1, 2]` becomes
  `[1, 1, 2, 2]`. No scaling and no noise are applied.

  Args:
    v: A vector with shape [..., a].

  Returns:
    A vector with shape [..., 2 * a].
  """
  doubled = np.repeat(v, repeats=2, axis=-1)
  return doubled
def expand_1_axis(w: np.ndarray,
                  epsilon: float,
                  axis: int) -> np.ndarray:
  """Expands either the first dimension or the last dimension of w.

  Every slice along `axis` is duplicated and halved. Symmetric noise is then
  added: the first copy of each pair receives `+noise / 2` and the second copy
  `-noise / 2`, so the two copies of a pair always sum to the original slice.

  If `axis = 0`, the following constraint will be satisfied:
    matmul(x, w) ==
        matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))

  If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
    expand_vector(matmul(x, w)) ==
        2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))

  Rank-1 inputs are supported; for them `axis = 0` and `axis = -1` refer to the
  same (only) dimension.

  Args:
    w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
    epsilon: Symmetric Noise added to expanded tensor. The noise is drawn from
      a zero-mean normal distribution with standard deviation
      `abs(w) * epsilon`.
    axis: Must be either 0 or -1.

  Returns:
    Expanded numpy array, with the size of `axis` doubled.

  Raises:
    AssertionError: If `axis` is neither 0 nor -1.
  """
  assert axis in (0, -1), (
      "Only support expanding the first or the last dimension. "
      "Got: {}".format(axis))

  rank = len(w.shape)

  d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
  d_w = np.repeat(d_w, 2, axis=axis)

  # Alternating (+1, -1) pattern along the expanded axis. All other dimensions
  # are left at size 1 and rely on broadcasting, which also makes the mask
  # valid for rank-1 inputs (the previous tile-based construction produced a
  # 2-D mask in that case and failed on the in-place multiply).
  sign_shape = [1] * rank
  sign_shape[axis] = d_w.shape[axis]
  sign_flip = np.tile(np.array([1, -1]), w.shape[axis]).reshape(sign_shape)
  d_w *= sign_flip

  w_expand = (np.repeat(w, 2, axis=axis) + d_w) / 2
  return w_expand
def expand_2_axes(w: np.ndarray,
                  epsilon: float) -> np.ndarray:
  """Doubles both the first and the last dimension of w.

  Every entry is duplicated along axis 0 and along axis -1 and then halved.
  Symmetric noise is added with alternating sign along axis 0, so the two
  copies of each original row slice cancel each other's noise when summed.

  The following constraint will be satisfied:
    expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))

  Args:
    w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
    epsilon: Symmetric Noise added to expanded tensor. The noise is drawn from
      a zero-mean normal distribution with standard deviation
      `abs(w) * epsilon`.

  Returns:
    Expanded numpy array.
  """

  def _double_first_and_last(arr):
    # Repeat along axis 0 first, then along the last axis.
    return np.repeat(np.repeat(arr, 2, axis=0), 2, axis=-1)

  ndim = len(w.shape)

  noise = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
  noise = _double_first_and_last(noise)

  # Column of (+1, -1) with trailing singleton dims, tiled to the full extent
  # of the first and last axes of the expanded array.
  signs = np.array([1, -1]).reshape([2] + [1] * (ndim - 1))
  tile_reps = [w.shape[0]] + [1] * (ndim - 2) + [w.shape[-1] * 2]
  signs = np.tile(signs, tile_reps)
  noise *= signs

  return (_double_first_and_last(w) + noise) / 2
def var_to_var(var_from: tf.Variable,
               var_to: tf.Variable,
               epsilon: float):
  """Assigns an expanded copy of `var_from` to `var_to`.

  Assume the shape of `var_from` is (a, b, ..., y, z). The shape of `var_to`
  can be identical, or one of (a, ..., z * 2), (a * 2, ..., z * 2),
  (a * 2, ..., z). Which expansion is applied is decided purely from the two
  shapes:

  * Same shape: `var_from` is copied as is.
  * Both rank 1: `expand_vector` is applied.
  * `var_to` is (2 * a, ..., z):
      For any x, tf.matmul(expand_vector(x), var_to) == tf.matmul(x, var_from)
  * `var_to` is (a, ..., 2 * z):
      For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
      Note that there will be noise added to the left hand side, if
      epsilon != 0.
  * `var_to` is (2 * a, ..., 2 * z):
      For any x, tf.matmul(expand_vector(x), var_to) ==
          expand_vector(tf.matmul(x, var_from))

  Args:
    var_from: input variable to expand.
    var_to: output variable.
    epsilon: the noise ratio that will be added, when splitting `var_from`.

  Raises:
    ValueError: If the pair of shapes matches none of the supported cases.
  """
  src_shape = var_from.shape
  dst_shape = var_to.shape

  # Nothing to expand: plain copy.
  if src_shape == dst_shape:
    var_to.assign(var_from)
    return

  # Rank-1 variables are handled before any first/last-dimension comparison.
  if len(src_shape) == 1 and len(dst_shape) == 1:
    var_to.assign(expand_vector(var_from.numpy()))
    return

  if src_shape[0] * 2 == dst_shape[0] and src_shape[-1] == dst_shape[-1]:
    expanded = expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0)
  elif src_shape[0] == dst_shape[0] and src_shape[-1] * 2 == dst_shape[-1]:
    expanded = expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1)
  elif src_shape[0] * 2 == dst_shape[0] and src_shape[-1] * 2 == dst_shape[-1]:
    expanded = expand_2_axes(var_from.numpy(), epsilon=epsilon)
  else:
    raise ValueError("Shape not supported, {}, {}".format(src_shape, dst_shape))

  var_to.assign(expanded)
def model_to_model_2x_wide(model_from: tf.Module,
                           model_to: tf.Module,
                           epsilon: float = 0.1):
  """Copies the weights of a model into a wider version of it.

  The trainable variables of the two models are paired up in order, and each
  pair is handed to `var_to_var`, which picks the expansion from the shapes.
  The expansion is designed so that the output of the model is not changed
  after expanding.

  For example:
  ```
  model_narrow = tf.keras.Sequential()
  model_narrow.add(tf.keras.Input(shape=(3,)))
  model_narrow.add(tf.keras.layers.Dense(4))
  model_narrow.add(tf.keras.layers.Dense(1))

  model_wide = tf.keras.Sequential()
  model_wide.add(tf.keras.Input(shape=(6,)))
  model_wide.add(tf.keras.layers.Dense(8))
  model_wide.add(tf.keras.layers.Dense(1))

  model_to_model_2x_wide(model_narrow, model_wide)
  assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
  ```

  We assume that `model_from` and `model_to` have the same architecture and
  that only their widths differ. Variables are matched with `zip`, so if one
  model has more trainable variables than the other, the extra ones are
  ignored.

  Args:
    model_from: input model to expand.
    model_to: output model whose variables will be assigned expanded values
      according to `model_from`.
    epsilon: the noise ratio that will be added, when splitting `var_from`.
  """
  variable_pairs = zip(model_from.trainable_variables,
                       model_to.trainable_variables)
  for src_var, dst_var in variable_pairs:
    logging.info("expanding %s %s to %s %s",
                 src_var.name, src_var.shape, dst_var.name, dst_var.shape)
    var_to_var(src_var, dst_var, epsilon=epsilon)
Prev
1
…
5
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment