ModelZoo / ResNet50_tensorflow · Commits

Commit 472e2f80
Authored Mar 16, 2024 by zhanggzh

    Merge remote-tracking branch 'tf_model/main'

Parents: d91296eb, f3a14f85

Changes: 215. Showing 15 changed files with 1824 additions and 0 deletions (+1824, -0):
models-2.13.1/official/legacy/image_classification/classifier_trainer_test.py        +238  -0
models-2.13.1/official/legacy/image_classification/classifier_trainer_util_test.py   +165  -0
models-2.13.1/official/legacy/image_classification/configs/__init__.py               +14   -0
models-2.13.1/official/legacy/image_classification/configs/base_configs.py           +256  -0
models-2.13.1/official/legacy/image_classification/configs/configs.py                +136  -0
models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml  +52  -0
models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml  +52  -0
models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml  +47  -0
models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml  +51  -0
models-2.13.1/official/legacy/image_classification/configs/examples/resnet/imagenet/gpu.yaml   +49  -0
models-2.13.1/official/legacy/image_classification/configs/examples/resnet/imagenet/tpu.yaml   +55  -0
models-2.13.1/official/legacy/image_classification/configs/examples/vgg16/imagenet/gpu.yaml    +46  -0
models-2.13.1/official/legacy/image_classification/dataset_factory.py                +533  -0
models-2.13.1/official/legacy/image_classification/efficientnet/__init__.py          +14   -0
models-2.13.1/official/legacy/image_classification/efficientnet/common_modules.py    +116  -0

models-2.13.1/official/legacy/image_classification/classifier_trainer_test.py (new file, mode 100644)

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the classifier trainer models."""
import functools
import json
import os
import sys

from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.legacy.image_classification import classifier_trainer
from official.utils.flags import core as flags_core

classifier_trainer.define_classifier_flags()


def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
  """Returns the combinations of end-to-end tests to run."""
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ],
      model=[
          'efficientnet',
          'resnet',
          'vgg',
      ],
      dataset=[
          'imagenet',
      ],
  )


def get_params_override(params_override: Mapping[str, Any]) -> str:
  """Converts params_override dict to string command."""
  return '--params_override=' + json.dumps(params_override)


def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
  """Returns a basic parameter configuration for testing."""
  return {
      'train_dataset': {
          'builder': 'synthetic',
          'use_per_replica_batch_size': True,
          'batch_size': 1,
          'image_size': 224,
          'dtype': dtype,
      },
      'validation_dataset': {
          'builder': 'synthetic',
          'batch_size': 1,
          'use_per_replica_batch_size': True,
          'image_size': 224,
          'dtype': dtype,
      },
      'train': {
          'steps': 1,
          'epochs': 1,
          'callbacks': {
              'enable_checkpoint_and_export': True,
              'enable_tensorboard': False,
          },
      },
      'evaluation': {
          'steps': 1,
      },
  }


@flagsaver.flagsaver
def run_end_to_end(main: Callable[[Any], None],
                   extra_flags: Optional[Iterable[str]] = None,
                   model_dir: Optional[str] = None):
  """Runs the classifier trainer end-to-end."""
  extra_flags = [] if extra_flags is None else extra_flags
  args = [sys.argv[0], '--model_dir', model_dir] + extra_flags
  flags_core.parse_flags(argv=args)
  main(flags.FLAGS)


class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for Keras models."""
  _tempdir = None

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(ClassifierTest, cls).setUpClass()

  def tearDown(self):
    super(ClassifierTest, self).tearDown()
    tf.io.gfile.rmtree(self.get_temp_dir())

  @combinations.generate(distribution_strategy_combinations())
  def test_end_to_end_train_and_eval(self, distribution, model, dataset):
    """Test train_and_eval and export for Keras classifier models."""
    # Some parameters are not defined as flags (e.g. cannot run
    # classifier_train.py --batch_size=...) by design, so use
    # "--params_override=..." instead
    model_dir = self.create_tempdir().full_path
    base_flags = [
        '--data_dir=not_used',
        '--model_type=' + model,
        '--dataset=' + dataset,
    ]
    train_and_eval_flags = base_flags + [
        get_params_override(basic_params_override()),
        '--mode=train_and_eval',
    ]
    run = functools.partial(
        classifier_trainer.run, strategy_override=distribution)
    run_end_to_end(
        main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy_gpu,
          ],
          model=[
              'efficientnet',
              'resnet',
              'vgg',
          ],
          dataset='imagenet',
          dtype='float16',
      ))
  def test_gpu_train(self, distribution, model, dataset, dtype):
    """Test train_and_eval and export for Keras classifier models."""
    # Some parameters are not defined as flags (e.g. cannot run
    # classifier_train.py --batch_size=...) by design, so use
    # "--params_override=..." instead
    model_dir = self.create_tempdir().full_path
    base_flags = [
        '--data_dir=not_used',
        '--model_type=' + model,
        '--dataset=' + dataset,
    ]
    train_and_eval_flags = base_flags + [
        get_params_override(basic_params_override(dtype)),
        '--mode=train_and_eval',
    ]
    export_params = basic_params_override()
    export_path = os.path.join(model_dir, 'export')
    export_params['export'] = {}
    export_params['export']['destination'] = export_path
    export_flags = base_flags + [
        '--mode=export_only',
        get_params_override(export_params),
    ]
    run = functools.partial(
        classifier_trainer.run, strategy_override=distribution)
    run_end_to_end(
        main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
    run_end_to_end(main=run, extra_flags=export_flags, model_dir=model_dir)
    self.assertTrue(os.path.exists(export_path))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.cloud_tpu_strategy,
          ],
          model=[
              'efficientnet',
              'resnet',
              'vgg',
          ],
          dataset='imagenet',
          dtype='bfloat16',
      ))
  def test_tpu_train(self, distribution, model, dataset, dtype):
    """Test train_and_eval and export for Keras classifier models."""
    # Some parameters are not defined as flags (e.g. cannot run
    # classifier_train.py --batch_size=...) by design, so use
    # "--params_override=..." instead
    model_dir = self.create_tempdir().full_path
    base_flags = [
        '--data_dir=not_used',
        '--model_type=' + model,
        '--dataset=' + dataset,
    ]
    train_and_eval_flags = base_flags + [
        get_params_override(basic_params_override(dtype)),
        '--mode=train_and_eval',
    ]
    run = functools.partial(
        classifier_trainer.run, strategy_override=distribution)
    run_end_to_end(
        main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)

  @combinations.generate(distribution_strategy_combinations())
  def test_end_to_end_invalid_mode(self, distribution, model, dataset):
    """Test the Keras EfficientNet model with `strategy`."""
    model_dir = self.create_tempdir().full_path
    extra_flags = [
        '--data_dir=not_used',
        '--mode=invalid_mode',
        '--model_type=' + model,
        '--dataset=' + dataset,
        get_params_override(basic_params_override()),
    ]
    run = functools.partial(
        classifier_trainer.run, strategy_override=distribution)
    with self.assertRaises(ValueError):
      run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)


if __name__ == '__main__':
  tf.test.main()
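
For reference, a minimal sketch (not part of the diff) of what the helpers above produce together: `json.dumps` serializes the override dict onto a single flag value, which is how the tests pass nested config past the flag parser.

# Minimal sketch, assuming only the helpers above: the override dict travels
# as one JSON-valued flag, e.g.
#   --params_override={"train_dataset": {"builder": "synthetic", ...}}
override_flag = get_params_override(basic_params_override(dtype='float16'))
assert override_flag.startswith('--params_override={"train_dataset"')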

models-2.13.1/official/legacy/image_classification/classifier_trainer_util_test.py (new file, mode 100644)

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import tensorflow as tf

from official.legacy.image_classification import classifier_trainer
from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification import test_utils
from official.legacy.image_classification.configs import base_configs


def get_trivial_model(num_classes: int) -> tf.keras.Model:
  """Creates and compiles trivial model for ImageNet dataset."""
  model = test_utils.trivial_model(num_classes=num_classes)
  lr = 0.01
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
  model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
  return model


def get_trivial_data() -> tf.data.Dataset:
  """Gets trivial data in the ImageNet size."""

  def generate_data(_) -> tf.data.Dataset:
    image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
    label = tf.zeros([1], dtype=tf.int32)
    return image, label

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.prefetch(buffer_size=1).batch(1)
  return dataset


class UtilTests(parameterized.TestCase, tf.test.TestCase):
  """Tests for individual utility functions within classifier_trainer.py."""

  @parameterized.named_parameters(
      ('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
      ('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
      ('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
      ('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
      ('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
      ('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
      ('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
      ('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
      ('resnet', 'resnet', '', None),
  )
  def test_get_model_size(self, model, model_name, expected):
    config = base_configs.ExperimentConfig(
        model_name=model,
        model=base_configs.ModelConfig(
            model_params={
                'model_name': model_name,
            },))
    size = classifier_trainer.get_image_size_from_model(config)
    self.assertEqual(size, expected)

  @parameterized.named_parameters(
      ('dynamic', 'dynamic', None, 'dynamic'),
      ('scalar', 128., None, 128.),
      ('float32', None, 'float32', 1),
      ('float16', None, 'float16', 128),
  )
  def test_get_loss_scale(self, loss_scale, dtype, expected):
    config = base_configs.ExperimentConfig(
        runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
    ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
    self.assertEqual(ls, expected)

  @parameterized.named_parameters(('float16', 'float16'),
                                  ('bfloat16', 'bfloat16'))
  def test_initialize(self, dtype):
    config = base_configs.ExperimentConfig(
        runtime=base_configs.RuntimeConfig(
            run_eagerly=False,
            enable_xla=False,
            per_gpu_thread_count=1,
            gpu_thread_mode='gpu_private',
            num_gpus=1,
            dataset_num_private_threads=1,
        ),
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
        model=base_configs.ModelConfig(),
    )

    class EmptyClass:
      pass

    fake_ds_builder = EmptyClass()
    fake_ds_builder.dtype = dtype
    fake_ds_builder.config = EmptyClass()
    classifier_trainer.initialize(config, fake_ds_builder)

  def test_resume_from_checkpoint(self):
    """Tests functionality for resuming from checkpoint."""
    # Set the keras policy
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    # Get the model, datasets, and compile it.
    model = get_trivial_model(10)

    # Create the checkpoint
    model_dir = self.create_tempdir().full_path
    train_epochs = 1
    train_steps = 10
    ds = get_trivial_data()
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
            save_weights_only=True)
    ]
    model.fit(
        ds,
        callbacks=callbacks,
        epochs=train_epochs,
        steps_per_epoch=train_steps)

    # Test load from checkpoint
    clean_model = get_trivial_model(10)
    weights_before_load = copy.deepcopy(clean_model.get_weights())
    initial_epoch = classifier_trainer.resume_from_checkpoint(
        model=clean_model, model_dir=model_dir, train_steps=train_steps)
    self.assertEqual(initial_epoch, 1)
    self.assertNotAllClose(weights_before_load, clean_model.get_weights())

    tf.io.gfile.rmtree(model_dir)

  def test_serialize_config(self):
    """Tests functionality for serializing data."""
    config = base_configs.ExperimentConfig()
    model_dir = self.create_tempdir().full_path
    classifier_trainer.serialize_config(params=config, model_dir=model_dir)
    saved_params_path = os.path.join(model_dir, 'params.yaml')
    self.assertTrue(os.path.exists(saved_params_path))
    tf.io.gfile.rmtree(model_dir)


if __name__ == '__main__':
  tf.test.main()
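
One detail worth calling out in `test_resume_from_checkpoint`: the `model.ckpt-{epoch:04d}` pattern relies on Keras expanding the epoch placeholder at save time, which is what lets `resume_from_checkpoint` recover an epoch number from the latest checkpoint name. A standalone sketch of just that mechanism (the toy model and path are hypothetical, not from the diff):

import os
import numpy as np
import tensorflow as tf

model_dir = '/tmp/ckpt_demo'  # hypothetical scratch path
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(model_dir, 'model.ckpt-{epoch:04d}'), save_weights_only=True)
model.fit(np.zeros((8, 4)), np.zeros((8, 1)),
          epochs=2, callbacks=[ckpt_cb], verbose=0)
# Keras expanded the placeholder once per epoch, so the newest checkpoint is:
assert tf.train.latest_checkpoint(model_dir).endswith('model.ckpt-0002')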

models-2.13.1/official/legacy/image_classification/configs/__init__.py (new file, mode 100644)

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

models-2.13.1/official/legacy/image_classification/configs/base_configs.py (new file, mode 100644)

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions for high level configuration groups.."""
import dataclasses
from typing import Any, List, Optional

from official.core import config_definitions
from official.modeling import hyperparams

RuntimeConfig = config_definitions.RuntimeConfig


@dataclasses.dataclass
class TensorBoardConfig(hyperparams.Config):
  """Configuration for TensorBoard.

  Attributes:
    track_lr: Whether or not to track the learning rate in TensorBoard.
      Defaults to True.
    write_model_weights: Whether or not to write the model weights as images
      in TensorBoard. Defaults to False.
  """
  track_lr: bool = True
  write_model_weights: bool = False


@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
  """Configuration for Callbacks.

  Attributes:
    enable_checkpoint_and_export: Whether or not to enable checkpoints as a
      Callback. Defaults to True.
    enable_backup_and_restore: Whether or not to add a BackupAndRestore
      callback. Defaults to False.
    enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
      Defaults to True.
    enable_time_history: Whether or not to enable TimeHistory Callbacks.
      Defaults to True.
  """
  enable_checkpoint_and_export: bool = True
  enable_backup_and_restore: bool = False
  enable_tensorboard: bool = True
  enable_time_history: bool = True


@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
  """Configuration for exports.

  Attributes:
    checkpoint: the path to the checkpoint to export.
    destination: the path to where the checkpoint should be exported.
  """
  checkpoint: str = None
  destination: str = None


@dataclasses.dataclass
class MetricsConfig(hyperparams.Config):
  """Configuration for Metrics.

  Attributes:
    accuracy: Whether or not to track accuracy as a Callback. Defaults to
      None.
    top_5: Whether or not to track top_5_accuracy as a Callback. Defaults to
      None.
  """
  accuracy: bool = None
  top_5: bool = None


@dataclasses.dataclass
class TimeHistoryConfig(hyperparams.Config):
  """Configuration for the TimeHistory callback.

  Attributes:
    log_steps: Interval of steps between logging of batch level stats.
  """
  log_steps: int = None


@dataclasses.dataclass
class TrainConfig(hyperparams.Config):
  """Configuration for training.

  Attributes:
    resume_checkpoint: Whether or not to enable checkpoint loading. Defaults
      to None.
    epochs: The number of training epochs to run. Defaults to None.
    steps: The number of steps to run per epoch. If None, then this will be
      inferred based on the number of images and batch size. Defaults to None.
    callbacks: An instance of CallbacksConfig.
    metrics: An instance of MetricsConfig.
    tensorboard: An instance of TensorBoardConfig.
    time_history: An instance of TimeHistoryConfig.
    set_epoch_loop: Whether or not to set `steps_per_execution` to equal the
      number of training steps in `model.compile`. This reduces the number of
      callbacks run per epoch which significantly improves end-to-end TPU
      training time.
  """
  resume_checkpoint: bool = None
  epochs: int = None
  steps: int = None
  callbacks: CallbacksConfig = CallbacksConfig()
  metrics: MetricsConfig = None
  tensorboard: TensorBoardConfig = TensorBoardConfig()
  time_history: TimeHistoryConfig = TimeHistoryConfig()
  set_epoch_loop: bool = False


@dataclasses.dataclass
class EvalConfig(hyperparams.Config):
  """Configuration for evaluation.

  Attributes:
    epochs_between_evals: The number of train epochs to run between
      evaluations. Defaults to None.
    steps: The number of eval steps to run during evaluation. If None, this
      will be inferred based on the number of images and batch size. Defaults
      to None.
    skip_eval: Whether or not to skip evaluation.
  """
  epochs_between_evals: int = None
  steps: int = None
  skip_eval: bool = False


@dataclasses.dataclass
class LossConfig(hyperparams.Config):
  """Configuration for Loss.

  Attributes:
    name: The name of the loss. Defaults to None.
    label_smoothing: Whether or not to apply label smoothing to the loss.
      This only applies to 'categorical_cross_entropy'.
  """
  name: str = None
  label_smoothing: float = None


@dataclasses.dataclass
class OptimizerConfig(hyperparams.Config):
  """Configuration for Optimizers.

  Attributes:
    name: The name of the optimizer. Defaults to None.
    decay: Decay or rho, discounting factor for gradient. Defaults to None.
    epsilon: Small value used to avoid 0 denominator. Defaults to None.
    momentum: Plain momentum constant. Defaults to None.
    nesterov: Whether or not to apply Nesterov momentum. Defaults to None.
    moving_average_decay: The amount of decay to apply. If 0 or None, then
      exponential moving average is not used. Defaults to None.
    lookahead: Whether or not to apply the lookahead optimizer. Defaults to
      None.
    beta_1: The exponential decay rate for the 1st moment estimates. Used in
      the Adam optimizers. Defaults to None.
    beta_2: The exponential decay rate for the 2nd moment estimates. Used in
      the Adam optimizers. Defaults to None.
  """
  name: str = None
  decay: float = None
  epsilon: float = None
  momentum: float = None
  nesterov: bool = None
  moving_average_decay: Optional[float] = None
  lookahead: Optional[bool] = None
  beta_1: float = None
  beta_2: float = None


@dataclasses.dataclass
class LearningRateConfig(hyperparams.Config):
  """Configuration for learning rates.

  Attributes:
    name: The name of the learning rate. Defaults to None.
    initial_lr: The initial learning rate. Defaults to None.
    decay_epochs: The number of decay epochs. Defaults to None.
    decay_rate: The rate of decay. Defaults to None.
    warmup_epochs: The number of warmup epochs. Defaults to None.
    batch_lr_multiplier: The multiplier to apply to the base learning rate,
      if necessary. Defaults to None.
    examples_per_epoch: the number of examples in a single epoch. Defaults to
      None.
    boundaries: boundaries used in piecewise constant decay with warmup.
    multipliers: multipliers used in piecewise constant decay with warmup.
    scale_by_batch_size: Scale the learning rate by a fraction of the batch
      size. Set to 0 for no scaling (default).
    staircase: Apply exponential decay at discrete values instead of
      continuous.
  """
  name: str = None
  initial_lr: float = None
  decay_epochs: float = None
  decay_rate: float = None
  warmup_epochs: int = None
  examples_per_epoch: int = None
  boundaries: List[int] = None
  multipliers: List[float] = None
  scale_by_batch_size: float = 0.
  staircase: bool = None


@dataclasses.dataclass
class ModelConfig(hyperparams.Config):
  """Configuration for Models.

  Attributes:
    name: The name of the model. Defaults to None.
    model_params: The parameters used to create the model. Defaults to None.
    num_classes: The number of classes in the model. Defaults to None.
    loss: A `LossConfig` instance. Defaults to None.
    optimizer: An `OptimizerConfig` instance. Defaults to None.
  """
  name: str = None
  model_params: hyperparams.Config = None
  num_classes: int = None
  loss: LossConfig = None
  optimizer: OptimizerConfig = None


@dataclasses.dataclass
class ExperimentConfig(hyperparams.Config):
  """Base configuration for an image classification experiment.

  Attributes:
    model_dir: The directory to use when running an experiment.
    model_name: The name of the model family, e.g. 'efficientnet' or 'resnet'.
    mode: e.g. 'train_and_eval', 'export'
    runtime: A `RuntimeConfig` instance.
    train_dataset: A dataset config for the training data.
    validation_dataset: A dataset config for the validation data.
    train: A `TrainConfig` instance.
    evaluation: An `EvalConfig` instance.
    model: A `ModelConfig` instance.
    export: An `ExportConfig` instance.
  """
  model_dir: str = None
  model_name: str = None
  mode: str = None
  runtime: RuntimeConfig = None
  train_dataset: Any = None
  validation_dataset: Any = None
  train: TrainConfig = None
  evaluation: EvalConfig = None
  model: ModelConfig = None
  export: ExportConfig = None
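
These dataclasses all inherit `hyperparams.Config`, which is what lets YAML files and `--params_override` JSON be merged into typed config trees. A minimal sketch of the copy-with-overrides pattern (assuming only this module; `replace()` is the same helper `dataset_factory.DatasetBuilder` in this commit uses):

# Minimal sketch, assuming only base_configs: derive a config variant via
# Config.replace() without mutating the original instance.
from official.legacy.image_classification.configs import base_configs

train = base_configs.TrainConfig(
    epochs=90,
    callbacks=base_configs.CallbacksConfig(enable_tensorboard=False))
smoke_test = train.replace(epochs=1, steps=10)
assert smoke_test.epochs == 1 and train.epochs == 90  # original untouched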

models-2.13.1/official/legacy/image_classification/configs/configs.py (new file, mode 100644)

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration utils for image classification experiments."""
import dataclasses

from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification.configs import base_configs
from official.legacy.image_classification.efficientnet import efficientnet_config
from official.legacy.image_classification.resnet import resnet_config
from official.legacy.image_classification.vgg import vgg_config


@dataclasses.dataclass
class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
  """Base configuration to train efficientnet-b0 on ImageNet.

  Attributes:
    export: An `ExportConfig` instance
    runtime: A `RuntimeConfig` instance.
    dataset: A `DatasetConfig` instance.
    train: A `TrainConfig` instance.
    evaluation: An `EvalConfig` instance.
    model: A `ModelConfig` instance.
  """
  export: base_configs.ExportConfig = base_configs.ExportConfig()
  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
  train_dataset: dataset_factory.DatasetConfig = \
      dataset_factory.ImageNetConfig(split='train')
  validation_dataset: dataset_factory.DatasetConfig = \
      dataset_factory.ImageNetConfig(split='validation')
  train: base_configs.TrainConfig = base_configs.TrainConfig(
      resume_checkpoint=True,
      epochs=500,
      steps=None,
      callbacks=base_configs.CallbacksConfig(
          enable_checkpoint_and_export=True, enable_tensorboard=True),
      metrics=['accuracy', 'top_5'],
      time_history=base_configs.TimeHistoryConfig(log_steps=100),
      tensorboard=base_configs.TensorBoardConfig(
          track_lr=True, write_model_weights=False),
      set_epoch_loop=False)
  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
      epochs_between_evals=1, steps=None)
  model: base_configs.ModelConfig = \
      efficientnet_config.EfficientNetModelConfig()


@dataclasses.dataclass
class ResNetImagenetConfig(base_configs.ExperimentConfig):
  """Base configuration to train resnet-50 on ImageNet."""
  export: base_configs.ExportConfig = base_configs.ExportConfig()
  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
  train_dataset: dataset_factory.DatasetConfig = \
      dataset_factory.ImageNetConfig(
          split='train', one_hot=False, mean_subtract=True, standardize=True)
  validation_dataset: dataset_factory.DatasetConfig = \
      dataset_factory.ImageNetConfig(
          split='validation',
          one_hot=False,
          mean_subtract=True,
          standardize=True)
  train: base_configs.TrainConfig = base_configs.TrainConfig(
      resume_checkpoint=True,
      epochs=90,
      steps=None,
      callbacks=base_configs.CallbacksConfig(
          enable_checkpoint_and_export=True, enable_tensorboard=True),
      metrics=['accuracy', 'top_5'],
      time_history=base_configs.TimeHistoryConfig(log_steps=100),
      tensorboard=base_configs.TensorBoardConfig(
          track_lr=True, write_model_weights=False),
      set_epoch_loop=False)
  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
      epochs_between_evals=1, steps=None)
  model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()


@dataclasses.dataclass
class VGGImagenetConfig(base_configs.ExperimentConfig):
  """Base configuration to train vgg-16 on ImageNet."""
  export: base_configs.ExportConfig = base_configs.ExportConfig()
  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
  train_dataset: dataset_factory.DatasetConfig = \
      dataset_factory.ImageNetConfig(
          split='train', one_hot=False, mean_subtract=True, standardize=True)
  validation_dataset: dataset_factory.DatasetConfig = \
      dataset_factory.ImageNetConfig(
          split='validation',
          one_hot=False,
          mean_subtract=True,
          standardize=True)
  train: base_configs.TrainConfig = base_configs.TrainConfig(
      resume_checkpoint=True,
      epochs=90,
      steps=None,
      callbacks=base_configs.CallbacksConfig(
          enable_checkpoint_and_export=True, enable_tensorboard=True),
      metrics=['accuracy', 'top_5'],
      time_history=base_configs.TimeHistoryConfig(log_steps=100),
      tensorboard=base_configs.TensorBoardConfig(
          track_lr=True, write_model_weights=False),
      set_epoch_loop=False)
  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
      epochs_between_evals=1, steps=None)
  model: base_configs.ModelConfig = vgg_config.VGGModelConfig()


def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
  """Given model and dataset names, return the ExperimentConfig."""
  dataset_model_config_map = {
      'imagenet': {
          'efficientnet': EfficientNetImageNetConfig(),
          'resnet': ResNetImagenetConfig(),
          'vgg': VGGImagenetConfig(),
      }
  }
  try:
    return dataset_model_config_map[dataset][model]
  except KeyError:
    if dataset not in dataset_model_config_map:
      raise KeyError('Invalid dataset received. Received: {}. Supported '
                     'datasets include: {}'.format(
                         dataset, ', '.join(dataset_model_config_map.keys())))
    raise KeyError('Invalid model received. Received: {}. Supported models '
                   'for {} include: {}'.format(
                       model, dataset,
                       ', '.join(dataset_model_config_map[dataset].keys())))
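
A short usage sketch (not part of the diff) of `get_config`, including the error path:

# Usage sketch, assuming only this module: resolve a base experiment config
# by (dataset, model) name, then adjust it for a quick run.
from official.legacy.image_classification.configs import configs

config = configs.get_config(model='resnet', dataset='imagenet')
assert config.train.epochs == 90

try:
  configs.get_config(model='densenet', dataset='imagenet')  # unsupported
except KeyError as e:
  print(e)  # names the supported models for 'imagenet'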

models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml (new file, mode 100644)

# Training configuration for EfficientNet-b0 trained on ImageNet on GPUs.
# Takes ~32 minutes per epoch for 8 V100s.
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'mirrored'
  num_gpus: 1
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'train'
  num_classes: 1000
  num_examples: 1281167
  batch_size: 32
  use_per_replica_batch_size: true
  dtype: 'float32'
  augmenter:
    name: 'autoaugment'
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'validation'
  num_classes: 1000
  num_examples: 50000
  batch_size: 32
  use_per_replica_batch_size: true
  dtype: 'float32'
model:
  model_params:
    model_name: 'efficientnet-b0'
    overrides:
      num_classes: 1000
      batch_norm: 'default'
      dtype: 'float32'
      activation: 'swish'
  optimizer:
    name: 'rmsprop'
    momentum: 0.9
    decay: 0.9
    moving_average_decay: 0.0
    lookahead: false
  learning_rate:
    name: 'exponential'
  loss:
    label_smoothing: 0.1
train:
  resume_checkpoint: true
  epochs: 500
evaluation:
  epochs_between_evals: 1
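
These YAMLs feed the same `classifier_trainer` that the unit tests above drive through flags. A hedged launch sketch (not part of the diff): `--config_file` is assumed to be among the flags `define_classifier_flags()` registers, and the data and model paths are placeholders:

# Hedged launch sketch: same flag surface the tests use, plus --config_file
# (assumed) pointing at this YAML. All paths below are hypothetical.
import sys
from absl import flags
from official.legacy.image_classification import classifier_trainer
from official.utils.flags import core as flags_core

classifier_trainer.define_classifier_flags()
flags_core.parse_flags(argv=[
    sys.argv[0],
    '--mode=train_and_eval',
    '--model_type=efficientnet',
    '--dataset=imagenet',
    '--model_dir=/tmp/efficientnet-b0',
    '--data_dir=/data/imagenet/tfrecords',
    '--config_file=models-2.13.1/official/legacy/image_classification/'
    'configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml',
])
classifier_trainer.run(flags.FLAGS)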

models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml (new file, mode 100644)

# Training configuration for EfficientNet-b0 trained on ImageNet on TPUs.
# Takes ~2 minutes, 50 seconds per epoch for v3-32.
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'tpu'
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'train'
  num_classes: 1000
  num_examples: 1281167
  batch_size: 128
  use_per_replica_batch_size: true
  dtype: 'bfloat16'
  augmenter:
    name: 'autoaugment'
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'validation'
  num_classes: 1000
  num_examples: 50000
  batch_size: 128
  use_per_replica_batch_size: true
  dtype: 'bfloat16'
model:
  model_params:
    model_name: 'efficientnet-b0'
    overrides:
      num_classes: 1000
      batch_norm: 'tpu'
      dtype: 'bfloat16'
      activation: 'swish'
  optimizer:
    name: 'rmsprop'
    momentum: 0.9
    decay: 0.9
    moving_average_decay: 0.0
    lookahead: false
  learning_rate:
    name: 'exponential'
  loss:
    label_smoothing: 0.1
train:
  resume_checkpoint: true
  epochs: 500
  set_epoch_loop: true
evaluation:
  epochs_between_evals: 1

models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml (new file, mode 100644)

# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'mirrored'
  num_gpus: 1
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'train'
  num_classes: 1000
  num_examples: 1281167
  batch_size: 32
  use_per_replica_batch_size: true
  dtype: 'float32'
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'validation'
  num_classes: 1000
  num_examples: 50000
  batch_size: 32
  use_per_replica_batch_size: true
  dtype: 'float32'
model:
  model_params:
    model_name: 'efficientnet-b1'
    overrides:
      num_classes: 1000
      batch_norm: 'default'
      dtype: 'float32'
      activation: 'swish'
  optimizer:
    name: 'rmsprop'
    momentum: 0.9
    decay: 0.9
    moving_average_decay: 0.0
    lookahead: false
  learning_rate:
    name: 'exponential'
  loss:
    label_smoothing: 0.1
train:
  resume_checkpoint: true
  epochs: 500
evaluation:
  epochs_between_evals: 1

models-2.13.1/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml (new file, mode 100644)

# Training configuration for EfficientNet-b1 trained on ImageNet on TPUs.
# Takes ~3 minutes, 15 seconds per epoch for v3-32.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'tpu'
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'train'
  num_classes: 1000
  num_examples: 1281167
  batch_size: 128
  use_per_replica_batch_size: true
  dtype: 'bfloat16'
  augmenter:
    name: 'autoaugment'
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'validation'
  num_classes: 1000
  num_examples: 50000
  batch_size: 128
  use_per_replica_batch_size: true
  dtype: 'bfloat16'
model:
  model_params:
    model_name: 'efficientnet-b1'
    overrides:
      num_classes: 1000
      batch_norm: 'tpu'
      dtype: 'bfloat16'
      activation: 'swish'
  optimizer:
    name: 'rmsprop'
    momentum: 0.9
    decay: 0.9
    moving_average_decay: 0.0
    lookahead: false
  learning_rate:
    name: 'exponential'
  loss:
    label_smoothing: 0.1
train:
  resume_checkpoint: true
  epochs: 500
  set_epoch_loop: true
evaluation:
  epochs_between_evals: 1

models-2.13.1/official/legacy/image_classification/configs/examples/resnet/imagenet/gpu.yaml (new file, mode 100644)

# Training configuration for ResNet trained on ImageNet on GPUs.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'mirrored'
  num_gpus: 1
  batchnorm_spatial_persistent: true
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'tfds'
  split: 'train'
  image_size: 224
  num_classes: 1000
  num_examples: 1281167
  batch_size: 256
  use_per_replica_batch_size: true
  dtype: 'float16'
  mean_subtract: true
  standardize: true
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'tfds'
  split: 'validation'
  image_size: 224
  num_classes: 1000
  num_examples: 50000
  batch_size: 256
  use_per_replica_batch_size: true
  dtype: 'float16'
  mean_subtract: true
  standardize: true
model:
  name: 'resnet'
  model_params:
    rescale_inputs: false
  optimizer:
    name: 'momentum'
    momentum: 0.9
    decay: 0.9
    epsilon: 0.001
  loss:
    label_smoothing: 0.1
train:
  resume_checkpoint: true
  epochs: 90
evaluation:
  epochs_between_evals: 1

models-2.13.1/official/legacy/image_classification/configs/examples/resnet/imagenet/tpu.yaml (new file, mode 100644)

# Training configuration for ResNet trained on ImageNet on TPUs.
# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'tpu'
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'tfds'
  split: 'train'
  one_hot: false
  image_size: 224
  num_classes: 1000
  num_examples: 1281167
  batch_size: 128
  use_per_replica_batch_size: true
  mean_subtract: false
  standardize: false
  dtype: 'bfloat16'
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'tfds'
  split: 'validation'
  one_hot: false
  image_size: 224
  num_classes: 1000
  num_examples: 50000
  batch_size: 128
  use_per_replica_batch_size: true
  mean_subtract: false
  standardize: false
  dtype: 'bfloat16'
model:
  name: 'resnet'
  model_params:
    rescale_inputs: true
  optimizer:
    name: 'momentum'
    momentum: 0.9
    decay: 0.9
    epsilon: 0.001
    moving_average_decay: 0.
    lookahead: false
  loss:
    label_smoothing: 0.1
train:
  callbacks:
    enable_checkpoint_and_export: true
  resume_checkpoint: true
  epochs: 90
  set_epoch_loop: true
evaluation:
  epochs_between_evals: 1

models-2.13.1/official/legacy/image_classification/configs/examples/vgg16/imagenet/gpu.yaml (new file, mode 100644)

# Training configuration for VGG-16 trained on ImageNet on GPUs.
# Reaches > 72.8% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
  distribution_strategy: 'mirrored'
  num_gpus: 1
  batchnorm_spatial_persistent: true
train_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'train'
  image_size: 224
  num_classes: 1000
  num_examples: 1281167
  batch_size: 128
  use_per_replica_batch_size: true
  dtype: 'float32'
  mean_subtract: true
  standardize: true
validation_dataset:
  name: 'imagenet2012'
  data_dir: null
  builder: 'records'
  split: 'validation'
  image_size: 224
  num_classes: 1000
  num_examples: 50000
  batch_size: 128
  use_per_replica_batch_size: true
  dtype: 'float32'
  mean_subtract: true
  standardize: true
model:
  name: 'vgg'
  optimizer:
    name: 'momentum'
    momentum: 0.9
    epsilon: 0.001
  loss:
    label_smoothing: 0.0
train:
  resume_checkpoint: true
  epochs: 90
evaluation:
  epochs_between_evals: 1

models-2.13.1/official/legacy/image_classification/dataset_factory.py (new file, mode 100644)

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import dataclasses
import os
from typing import Any, List, Mapping, Optional, Tuple, Union

from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds

from official.legacy.image_classification import augment
from official.legacy.image_classification import preprocessing
from official.modeling.hyperparams import base_config

AUGMENTERS = {
    'autoaugment': augment.AutoAugment,
    'randaugment': augment.RandAugment,
}


@dataclasses.dataclass
class AugmentConfig(base_config.Config):
  """Configuration for image augmenters.

  Attributes:
    name: The name of the image augmentation to use. Possible options are
      None (default), 'autoaugment', or 'randaugment'.
    params: Any parameters used to initialize the augmenter.
  """
  name: Optional[str] = None
  params: Optional[Mapping[str, Any]] = None

  def build(self) -> augment.ImageAugment:
    """Build the augmenter using this config."""
    params = self.params or {}
    augmenter = AUGMENTERS.get(self.name, None)
    return augmenter(**params) if augmenter is not None else None


@dataclasses.dataclass
class DatasetConfig(base_config.Config):
  """The base configuration for building datasets.

  Attributes:
    name: The name of the Dataset. Usually should correspond to a TFDS
      dataset.
    data_dir: The path where the dataset files are stored, if available.
    filenames: Optional list of strings representing the TFRecord names.
    builder: The builder type used to load the dataset. Value should be one
      of 'tfds' (load using TFDS), 'records' (load from TFRecords), or
      'synthetic' (generate dummy synthetic data without reading from files).
    split: The split of the dataset. Usually 'train', 'validation', or 'test'.
    image_size: The size of the image in the dataset. This assumes that
      `width` == `height`. Set to 'infer' to infer the image size from TFDS
      info. This requires `name` to be a registered dataset in TFDS.
    num_classes: The number of classes given by the dataset. Set to 'infer'
      to infer the number of classes from TFDS info. This requires `name` to
      be a registered dataset in TFDS.
    num_channels: The number of channels given by the dataset. Set to 'infer'
      to infer the number of channels from TFDS info. This requires `name` to
      be a registered dataset in TFDS.
    num_examples: The number of examples given by the dataset. Set to 'infer'
      to infer the number of examples from TFDS info. This requires `name` to
      be a registered dataset in TFDS.
    batch_size: The base batch size for the dataset.
    use_per_replica_batch_size: Whether to scale the batch size based on
      available resources. If set to `True`, the dataset builder will return
      batch_size multiplied by `num_devices`, the number of device replicas
      (e.g., the number of GPUs or TPU cores). This setting should be `True`
      if the strategy argument is passed to `build()` and `num_devices > 1`.
    num_devices: The number of replica devices to use. This should be set by
      `strategy.num_replicas_in_sync` when using a distribution strategy.
    dtype: The desired dtype of the dataset. This will be set during
      preprocessing.
    one_hot: Whether to apply one hot encoding. Set to `True` to be able to
      use label smoothing.
    augmenter: The augmenter config to use. No augmentation is used by
      default.
    download: Whether to download data using TFDS.
    shuffle_buffer_size: The buffer size used for shuffling training data.
    file_shuffle_buffer_size: The buffer size used for shuffling raw training
      files.
    skip_decoding: Whether to skip image decoding when loading from TFDS.
    cache: whether to cache dataset examples. Can be used to avoid re-reading
      from disk on the second epoch. Requires significant memory overhead.
    tf_data_service: The URI of a tf.data service to offload preprocessing
      onto during training. The URI should be in the format
      "protocol://address", e.g. "grpc://tf-data-service:5050".
    mean_subtract: whether or not to apply mean subtraction to the dataset.
    standardize: whether or not to apply standardization to the dataset.
  """
  name: Optional[str] = None
  data_dir: Optional[str] = None
  filenames: Optional[List[str]] = None
  builder: str = 'tfds'
  split: str = 'train'
  image_size: Union[int, str] = 'infer'
  num_classes: Union[int, str] = 'infer'
  num_channels: Union[int, str] = 'infer'
  num_examples: Union[int, str] = 'infer'
  batch_size: int = 128
  use_per_replica_batch_size: bool = True
  num_devices: int = 1
  dtype: str = 'float32'
  one_hot: bool = True
  augmenter: AugmentConfig = AugmentConfig()
  download: bool = False
  shuffle_buffer_size: int = 10000
  file_shuffle_buffer_size: int = 1024
  skip_decoding: bool = True
  cache: bool = False
  tf_data_service: Optional[str] = None
  mean_subtract: bool = False
  standardize: bool = False

  @property
  def has_data(self):
    """Whether this dataset has any data associated with it."""
    return self.name or self.data_dir or self.filenames


@dataclasses.dataclass
class ImageNetConfig(DatasetConfig):
  """The base ImageNet dataset config."""
  name: str = 'imagenet2012'
  # Note: for large datasets like ImageNet, using records is faster than tfds
  builder: str = 'records'
  image_size: int = 224
  num_channels: int = 3
  num_examples: int = 1281167
  num_classes: int = 1000
  batch_size: int = 128


@dataclasses.dataclass
class Cifar10Config(DatasetConfig):
  """The base CIFAR-10 dataset config."""
  name: str = 'cifar10'
  image_size: int = 224
  batch_size: int = 128
  download: bool = True
  cache: bool = True


class DatasetBuilder:
  """An object for building datasets.

  Allows building various pipelines fetching examples, preprocessing, etc.
  Maintains additional state information calculated from the dataset, i.e.,
  training set split, batch size, and number of steps (batches).
  """

  def __init__(self, config: DatasetConfig, **overrides: Any):
    """Initialize the builder from the config."""
    self.config = config.replace(**overrides)
    self.builder_info = None

    if self.config.augmenter is not None:
      logging.info('Using augmentation: %s', self.config.augmenter.name)
      self.augmenter = self.config.augmenter.build()
    else:
      self.augmenter = None

  @property
  def is_training(self) -> bool:
    """Whether this is the training set."""
    return self.config.split == 'train'

  @property
  def batch_size(self) -> int:
    """The batch size, multiplied by the number of replicas (if configured)."""
    if self.config.use_per_replica_batch_size:
      return self.config.batch_size * self.config.num_devices
    else:
      return self.config.batch_size

  @property
  def global_batch_size(self):
    """The global batch size across all replicas."""
    return self.batch_size

  @property
  def local_batch_size(self):
    """The base unscaled batch size."""
    if self.config.use_per_replica_batch_size:
      return self.config.batch_size
    else:
      return self.config.batch_size // self.config.num_devices

  @property
  def num_steps(self) -> int:
    """The number of steps (batches) to exhaust this dataset."""
    # Always divide by the global batch size to get the correct # of steps
    return self.num_examples // self.global_batch_size
@
property
def
dtype
(
self
)
->
tf
.
dtypes
.
DType
:
"""Converts the config's dtype string to a tf dtype.
Returns:
A mapping from string representation of a dtype to the `tf.dtypes.DType`.
Raises:
ValueError if the config's dtype is not supported.
"""
dtype_map
=
{
'float32'
:
tf
.
float32
,
'bfloat16'
:
tf
.
bfloat16
,
'float16'
:
tf
.
float16
,
'fp32'
:
tf
.
float32
,
'bf16'
:
tf
.
bfloat16
,
}
try
:
return
dtype_map
[
self
.
config
.
dtype
]
except
:
raise
ValueError
(
'Invalid DType provided. Supported types: {}'
.
format
(
dtype_map
.
keys
()))
@
property
def
image_size
(
self
)
->
int
:
"""The size of each image (can be inferred from the dataset)."""
if
self
.
config
.
image_size
==
'infer'
:
return
self
.
info
.
features
[
'image'
].
shape
[
0
]
else
:
return
int
(
self
.
config
.
image_size
)
@
property
def
num_channels
(
self
)
->
int
:
"""The number of image channels (can be inferred from the dataset)."""
if
self
.
config
.
num_channels
==
'infer'
:
return
self
.
info
.
features
[
'image'
].
shape
[
-
1
]
else
:
return
int
(
self
.
config
.
num_channels
)
@
property
def
num_examples
(
self
)
->
int
:
"""The number of examples (can be inferred from the dataset)."""
if
self
.
config
.
num_examples
==
'infer'
:
return
self
.
info
.
splits
[
self
.
config
.
split
].
num_examples
else
:
return
int
(
self
.
config
.
num_examples
)
@
property
def
num_classes
(
self
)
->
int
:
"""The number of classes (can be inferred from the dataset)."""
if
self
.
config
.
num_classes
==
'infer'
:
return
self
.
info
.
features
[
'label'
].
num_classes
else
:
return
int
(
self
.
config
.
num_classes
)
@
property
def
info
(
self
)
->
tfds
.
core
.
DatasetInfo
:
"""The TFDS dataset info, if available."""
try
:
if
self
.
builder_info
is
None
:
self
.
builder_info
=
tfds
.
builder
(
self
.
config
.
name
).
info
except
ConnectionError
as
e
:
logging
.
error
(
'Failed to use TFDS to load info. Please set dataset info '
'(image_size, num_channels, num_examples, num_classes) in '
'the dataset config.'
)
raise
e
return
self
.
builder_info
def
build
(
self
,
strategy
:
Optional
[
tf
.
distribute
.
Strategy
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Construct a dataset end-to-end and return it using an optional strategy.
Args:
strategy: a strategy that, if passed, will distribute the dataset
according to that strategy. If passed and `num_devices > 1`,
`use_per_replica_batch_size` must be set to `True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if
strategy
:
if
strategy
.
num_replicas_in_sync
!=
self
.
config
.
num_devices
:
logging
.
warn
(
'Passed a strategy with %d devices, but expected'
'%d devices.'
,
strategy
.
num_replicas_in_sync
,
self
.
config
.
num_devices
)
dataset
=
strategy
.
distribute_datasets_from_function
(
self
.
_build
)
else
:
dataset
=
self
.
_build
()
return
dataset
def
_build
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Construct a dataset end-to-end and return it.
Args:
input_context: An optional context provided by `tf.distribute` for
cross-replica training.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
builders
=
{
'tfds'
:
self
.
load_tfds
,
'records'
:
self
.
load_records
,
'synthetic'
:
self
.
load_synthetic
,
}
builder
=
builders
.
get
(
self
.
config
.
builder
,
None
)
if
builder
is
None
:
raise
ValueError
(
'Unknown builder type {}'
.
format
(
self
.
config
.
builder
))
self
.
input_context
=
input_context
dataset
=
builder
()
dataset
=
self
.
pipeline
(
dataset
)
return
dataset
def
load_tfds
(
self
)
->
tf
.
data
.
Dataset
:
"""Return a dataset loading files from TFDS."""
logging
.
info
(
'Using TFDS to load data.'
)
builder
=
tfds
.
builder
(
self
.
config
.
name
,
data_dir
=
self
.
config
.
data_dir
)
if
self
.
config
.
download
:
builder
.
download_and_prepare
()
decoders
=
{}
if
self
.
config
.
skip_decoding
:
decoders
[
'image'
]
=
tfds
.
decode
.
SkipDecoding
()
read_config
=
tfds
.
ReadConfig
(
interleave_cycle_length
=
10
,
interleave_block_length
=
1
,
input_context
=
self
.
input_context
)
dataset
=
builder
.
as_dataset
(
split
=
self
.
config
.
split
,
as_supervised
=
True
,
shuffle_files
=
True
,
decoders
=
decoders
,
read_config
=
read_config
)
return
dataset
def
load_records
(
self
)
->
tf
.
data
.
Dataset
:
"""Return a dataset loading files with TFRecords."""
logging
.
info
(
'Using TFRecords to load data.'
)
if
self
.
config
.
filenames
is
None
:
if
self
.
config
.
data_dir
is
None
:
raise
ValueError
(
'Dataset must specify a path for the data files.'
)
file_pattern
=
os
.
path
.
join
(
self
.
config
.
data_dir
,
'{}*'
.
format
(
self
.
config
.
split
))
dataset
=
tf
.
data
.
Dataset
.
list_files
(
file_pattern
,
shuffle
=
False
)
else
:
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
self
.
config
.
filenames
)
return
dataset
def
load_synthetic
(
self
)
->
tf
.
data
.
Dataset
:
"""Return a dataset generating dummy synthetic data."""
logging
.
info
(
'Generating a synthetic dataset.'
)
def
generate_data
(
_
):
image
=
tf
.
zeros
([
self
.
image_size
,
self
.
image_size
,
self
.
num_channels
],
dtype
=
self
.
dtype
)
label
=
tf
.
zeros
([
1
],
dtype
=
tf
.
int32
)
return
image
,
label
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
map
(
generate_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
def
pipeline
(
self
,
dataset
:
tf
.
data
.
Dataset
)
->
tf
.
data
.
Dataset
:
"""Build a pipeline fetching, shuffling, and preprocessing the dataset.
Args:
dataset: A `tf.data.Dataset` that loads raw files.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if
(
self
.
config
.
builder
!=
'tfds'
and
self
.
input_context
and
self
.
input_context
.
num_input_pipelines
>
1
):
dataset
=
dataset
.
shard
(
self
.
input_context
.
num_input_pipelines
,
self
.
input_context
.
input_pipeline_id
)
logging
.
info
(
'Sharding the dataset: input_pipeline_id=%d '
'num_input_pipelines=%d'
,
self
.
input_context
.
num_input_pipelines
,
self
.
input_context
.
input_pipeline_id
)
if
self
.
is_training
and
self
.
config
.
builder
==
'records'
:
# Shuffle the input files.
dataset
.
shuffle
(
buffer_size
=
self
.
config
.
file_shuffle_buffer_size
)
if
self
.
is_training
and
not
self
.
config
.
cache
:
dataset
=
dataset
.
repeat
()
if
self
.
config
.
builder
==
'records'
:
# Read the data from disk in parallel
dataset
=
dataset
.
interleave
(
tf
.
data
.
TFRecordDataset
,
cycle_length
=
10
,
block_length
=
1
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
if
self
.
config
.
cache
:
dataset
=
dataset
.
cache
()
if
self
.
is_training
:
dataset
=
dataset
.
shuffle
(
self
.
config
.
shuffle_buffer_size
)
dataset
=
dataset
.
repeat
()
# Parse, pre-process, and batch the data in parallel
if
self
.
config
.
builder
==
'records'
:
preprocess
=
self
.
parse_record
else
:
preprocess
=
self
.
preprocess
dataset
=
dataset
.
map
(
preprocess
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
if
self
.
input_context
and
self
.
config
.
num_devices
>
1
:
if
not
self
.
config
.
use_per_replica_batch_size
:
raise
ValueError
(
'The builder does not support a global batch size with more than '
'one replica. Got {} replicas. Please set a '
'`per_replica_batch_size` and enable '
'`use_per_replica_batch_size=True`.'
.
format
(
self
.
config
.
num_devices
))
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
dataset
=
dataset
.
batch
(
self
.
local_batch_size
,
drop_remainder
=
self
.
is_training
)
else
:
dataset
=
dataset
.
batch
(
self
.
global_batch_size
,
drop_remainder
=
self
.
is_training
)
# Prefetch overlaps in-feed with training
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
if
self
.
config
.
tf_data_service
:
if
not
hasattr
(
tf
.
data
.
experimental
,
'service'
):
raise
ValueError
(
'The tf_data_service flag requires Tensorflow version '
'>= 2.3.0, but the version is {}'
.
format
(
tf
.
__version__
))
dataset
=
dataset
.
apply
(
tf
.
data
.
experimental
.
service
.
distribute
(
processing_mode
=
'parallel_epochs'
,
service
=
self
.
config
.
tf_data_service
,
job_name
=
'resnet_train'
))
dataset
=
dataset
.
prefetch
(
buffer_size
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
  def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parse an ImageNet record from a serialized string Tensor."""
    keys_to_features = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string, ''),
        'image/format': tf.io.FixedLenFeature((), tf.string, 'jpeg'),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64, -1),
        'image/class/text': tf.io.FixedLenFeature([], tf.string, ''),
        'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
    }

    parsed = tf.io.parse_single_example(record, keys_to_features)

    label = tf.reshape(parsed['image/class/label'], shape=[1])
    # Subtract one so that labels are in [0, 1000).
    label -= 1

    image_bytes = tf.reshape(parsed['image/encoded'], shape=[])
    image, label = self.preprocess(image_bytes, label)

    return image, label
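To make the feature spec concrete, here is a small self-contained sketch (values invented for illustration) of a record in the format `parse_record` expects:

import tensorflow as tf

jpeg_bytes = tf.io.encode_jpeg(tf.zeros([4, 4, 3], tf.uint8)).numpy()
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
    'image/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[538])),  # labels are 1-based
}))
record = example.SerializeToString()
# Parsing this record yields the decoded image and the 0-based label 537;
# the bbox features are VarLenFeature and may be absent entirely.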
  def preprocess(self, image: tf.Tensor,
                 label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Apply image preprocessing and augmentation to the image and label."""
    if self.is_training:
      image = preprocessing.preprocess_for_train(
          image,
          image_size=self.image_size,
          mean_subtract=self.config.mean_subtract,
          standardize=self.config.standardize,
          dtype=self.dtype,
          augmenter=self.augmenter)
    else:
      image = preprocessing.preprocess_for_eval(
          image,
          image_size=self.image_size,
          num_channels=self.num_channels,
          mean_subtract=self.config.mean_subtract,
          standardize=self.config.standardize,
          dtype=self.dtype)

    label = tf.cast(label, tf.int32)
    if self.config.one_hot:
      label = tf.one_hot(label, self.num_classes)
      label = tf.reshape(label, [self.num_classes])

    return image, label
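A quick shape check on the one-hot branch: `label` arrives with shape [1] from `parse_record`, so `tf.one_hot` produces a [1, num_classes] tensor and the final reshape flattens it. A tiny example with five classes:

import tensorflow as tf

label = tf.constant([3], dtype=tf.int32)  # shape [1]
one_hot = tf.one_hot(label, 5)            # shape [1, 5]
flat = tf.reshape(one_hot, [5])           # shape [5]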
  @classmethod
  def from_params(cls, *args, **kwargs):
    """Construct a dataset builder from a default config and any overrides."""
    config = DatasetConfig.from_args(*args, **kwargs)
    return cls(config)
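A hedged usage sketch: assuming the surrounding class is this file's dataset builder and that the keyword arguments are valid `DatasetConfig` field overrides (the field names below are assumptions for illustration), construction looks roughly like:

builder = DatasetBuilder.from_params(
    split='train', batch_size=128, one_hot=True)
dataset = builder.build()  # assuming a `build` entry point on the class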
models-2.13.1/official/legacy/image_classification/efficientnet/__init__.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
models-2.13.1/official/legacy/image_classification/efficientnet/common_modules.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common modeling utilities."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
typing
import
Optional
,
Text
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow.compat.v1
as
tf1
from
tensorflow.python.tpu
import
tpu_function
@tf.keras.utils.register_keras_serializable(package='Vision')
class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
  """Cross replica batch normalization."""

  def __init__(self, fused: Optional[bool] = False, **kwargs):
    if fused in (True, None):
      raise ValueError('TpuBatchNormalization does not support fused=True.')
    super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
  def _cross_replica_average(self, t: tf.Tensor, num_shards_per_group: int):
    """Calculates the average value of input tensor across TPU replicas."""
    num_shards = tpu_function.get_tpu_context().number_of_shards
    group_assignment = None
    if num_shards_per_group > 1:
      if num_shards % num_shards_per_group != 0:
        raise ValueError(
            'num_shards: %d mod shards_per_group: %d, should be 0' %
            (num_shards, num_shards_per_group))
      num_groups = num_shards // num_shards_per_group
      group_assignment = [[
          x for x in range(num_shards) if x // num_shards_per_group == y
      ] for y in range(num_groups)]
    return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
        num_shards_per_group, t.dtype)
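A worked example of the group-assignment construction above: with 16 shards and num_shards_per_group=8, the cross-replica sum averages within two groups of eight replicas each. The snippet is plain Python and runnable on its own:

num_shards, num_shards_per_group = 16, 8
num_groups = num_shards // num_shards_per_group  # 2
group_assignment = [[x for x in range(num_shards)
                     if x // num_shards_per_group == y]
                    for y in range(num_groups)]
# group_assignment == [[0, 1, 2, 3, 4, 5, 6, 7],
#                      [8, 9, 10, 11, 12, 13, 14, 15]]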
  def _moments(self, inputs: tf.Tensor, reduction_axes: int, keep_dims: int):
    """Compute the mean and variance: it overrides the original _moments."""
    shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
        inputs, reduction_axes, keep_dims=keep_dims)

    num_shards = tpu_function.get_tpu_context().number_of_shards or 1
    if num_shards <= 8:
      # Skip cross_replica for 2x2 or smaller slices.
      num_shards_per_group = 1
    else:
      num_shards_per_group = max(8, num_shards // 8)
    if num_shards_per_group > 1:
      # Compute variance using: Var[X] = E[X^2] - E[X]^2.
      shard_square_of_mean = tf.math.square(shard_mean)
      shard_mean_of_square = shard_variance + shard_square_of_mean
      group_mean = self._cross_replica_average(
          shard_mean, num_shards_per_group)
      group_mean_of_square = self._cross_replica_average(
          shard_mean_of_square, num_shards_per_group)
      group_variance = group_mean_of_square - tf.math.square(group_mean)
      return (group_mean, group_variance)
    else:
      return (shard_mean, shard_variance)
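The variance pooling above relies on the identity Var[X] = E[X^2] - E[X]^2. A small NumPy check (shards of equal size, as on a TPU slice) confirms that combining per-shard moments this way reproduces the global moments:

import numpy as np

shards = [np.random.randn(32), np.random.randn(32)]
means = [s.mean() for s in shards]
variances = [s.var() for s in shards]

# Pool the per-shard statistics exactly as _moments does.
group_mean = np.mean(means)
mean_of_square = np.mean([v + m**2 for v, m in zip(variances, means)])
group_variance = mean_of_square - group_mean**2

assert np.allclose(group_mean, np.concatenate(shards).mean())
assert np.allclose(group_variance, np.concatenate(shards).var())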
def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
  """A helper to select a batch normalization implementation.

  Args:
    batch_norm_type: The type of batch normalization layer implementation.
      `tpu` will use `TpuBatchNormalization`.

  Returns:
    A batch normalization layer class: `TpuBatchNormalization` for `tpu`,
    otherwise `tf.keras.layers.BatchNormalization`.
  """
  if batch_norm_type == 'tpu':
    return TpuBatchNormalization

  return tf.keras.layers.BatchNormalization  # pytype: disable=bad-return-type  # typed-keras
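Note that the helper returns a layer class rather than an instance, so callers instantiate it themselves. A minimal usage sketch (keyword values invented):

bn_cls = get_batch_norm('tpu')  # -> TpuBatchNormalization
bn_layer = bn_cls(momentum=0.99, epsilon=1e-3)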
def count_params(model, trainable_only=True):
  """Returns the count of all model parameters, or just trainable ones."""
  if not trainable_only:
    return model.count_params()
  else:
    return int(
        np.sum([
            tf.keras.backend.count_params(p) for p in model.trainable_weights
        ]))
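As a quick sanity check, here is a sketch with a model whose count is easy to verify by hand: a Dense(2) layer on 3 inputs has 3*2 weights plus 2 biases, i.e. 8 parameters, all trainable.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
assert count_params(model) == 8
assert count_params(model, trainable_only=False) == 8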
def load_weights(model: tf.keras.Model,
                 model_weights_path: Text,
                 weights_format: Text = 'saved_model'):
  """Load model weights from the given file path.

  Args:
    model: the model to load weights into.
    model_weights_path: the path of the model weights.
    weights_format: the model weights format. One of 'saved_model', 'h5', or
      'checkpoint'.
  """
  if weights_format == 'saved_model':
    loaded_model = tf.keras.models.load_model(model_weights_path)
    model.set_weights(loaded_model.get_weights())
  else:
    model.load_weights(model_weights_path)
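A usage sketch with hypothetical paths and a pre-built `model`; pick the `weights_format` matching how the weights were exported (the 'h5' and 'checkpoint' cases both fall through to `model.load_weights`):

load_weights(model, '/tmp/model_savedmodel', weights_format='saved_model')
load_weights(model, '/tmp/model.h5', weights_format='h5')
load_weights(model, '/tmp/ckpt/model.ckpt', weights_format='checkpoint')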