Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0650ea24
Commit
0650ea24
authored
Jul 13, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Jul 13, 2021
Browse files
Release example project.
PiperOrigin-RevId: 384509754
parent
c9967b10
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
689 additions
and
0 deletions
+689
-0
official/vision/beta/projects/example/example_config.py
official/vision/beta/projects/example/example_config.py
+117
-0
official/vision/beta/projects/example/example_config_local.yaml
...al/vision/beta/projects/example/example_config_local.yaml
+32
-0
official/vision/beta/projects/example/example_config_tpu.yaml
...cial/vision/beta/projects/example/example_config_tpu.yaml
+35
-0
official/vision/beta/projects/example/example_input.py
official/vision/beta/projects/example/example_input.py
+137
-0
official/vision/beta/projects/example/example_model.py
official/vision/beta/projects/example/example_model.py
+102
-0
official/vision/beta/projects/example/example_task.py
official/vision/beta/projects/example/example_task.py
+209
-0
official/vision/beta/projects/example/registry_imports.py
official/vision/beta/projects/example/registry_imports.py
+27
-0
official/vision/beta/projects/example/train.py
official/vision/beta/projects/example/train.py
+30
-0
No files found.
official/vision/beta/projects/example/example_config.py
0 → 100644
View file @
0650ea24
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example experiment configuration definition."""
from
typing
import
List
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
@dataclasses.dataclass
class ExampleDataConfig(cfg.DataConfig):
  """Input config for training. Add more fields as needed."""
  # File pattern (e.g. a glob) pointing to the input data.
  input_path: str = ''
  # Batch size across all replicas; 0 means it must be set by the experiment.
  global_batch_size: int = 0
  # Whether this dataset is used for training (enables training-side parsing).
  is_training: bool = True
  # Dtype of the produced image tensors, e.g. 'float32' or 'bfloat16'.
  dtype: str = 'float32'
  # Buffer size used when shuffling examples.
  shuffle_buffer_size: int = 10000
  # Number of input files read concurrently (tf.data interleave cycle length).
  cycle_length: int = 10
  # Format of the input files; used to pick the dataset reader.
  file_type: str = 'tfrecord'
@dataclasses.dataclass
class ExampleModel(hyperparams.Config):
  """The model config. Used by build_example_model function."""
  # Number of classification classes; 0 means it must be set by the experiment.
  num_classes: int = 0
  # Input image size as [height, width, channels].
  input_size: List[int] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related config."""
  # Coefficient for L2 weight-decay regularization; 0.0 disables it.
  l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation-related config."""
  # `k` used by the top-k categorical accuracy metric.
  top_k: int = 5
@dataclasses.dataclass
class ExampleTask(cfg.TaskConfig):
  """The task config."""
  # Model architecture config consumed by build_example_model.
  model: ExampleModel = ExampleModel()
  # Input pipeline configs for the training and validation splits.
  train_data: ExampleDataConfig = ExampleDataConfig(is_training=True)
  validation_data: ExampleDataConfig = ExampleDataConfig(is_training=False)
  # Loss and evaluation-metric configs.
  losses: Losses = Losses()
  evaluation: Evaluation = Evaluation()
@exp_factory.register_config_factory('tf_vision_example_experiment')
def tf_vision_example_experiment() -> cfg.ExperimentConfig:
  """Definition of a full example experiment.

  Registers a complete experiment (task + trainer + restrictions) under the
  name 'tf_vision_example_experiment' so it can be selected by the trainer.

  Returns:
    A cfg.ExperimentConfig describing the example experiment.
  """
  train_batch_size = 256
  eval_batch_size = 256
  steps_per_epoch = 10

  # Task: a small 10-class model on 128x128x3 inputs with L2 weight decay.
  task = ExampleTask(
      model=ExampleModel(num_classes=10, input_size=[128, 128, 3]),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=ExampleDataConfig(
          input_path='/path/to/train*',
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=ExampleDataConfig(
          input_path='/path/to/valid*',
          is_training=False,
          global_batch_size=eval_batch_size))

  # Optimizer: SGD with momentum, cosine decay and a linear warmup.
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'cosine',
          'cosine': {
              'initial_learning_rate': 1.6,
              'decay_steps': 350 * steps_per_epoch
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0
          }
      }
  })

  # Trainer: run 90 "epochs", validating and checkpointing once per epoch.
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=90 * steps_per_epoch,
      validation_steps=steps_per_epoch,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
official/vision/beta/projects/example/example_config_local.yaml
0 → 100644
View file @
0650ea24
# Local (single-host) override config for the example experiment on
# ImageNet-sized data (1001 classes, 128x128 inputs).
task:
  model:
    num_classes: 1001
    input_size: [128, 128, 3]
  train_data:
    input_path: 'imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 64
    dtype: 'bfloat16'
  validation_data:
    input_path: 'imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 64
    dtype: 'bfloat16'
    # Keep the final partial batch so every validation example is evaluated.
    drop_remainder: false
trainer:
  train_steps: 62400
  validation_steps: 13
  # Validate, summarize and checkpoint once every 312 steps.
  validation_interval: 312
  steps_per_loop: 312
  summary_interval: 312
  checkpoint_interval: 312
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'stepwise'
      stepwise:
        # Learning-rate schedule: step drops at these global-step boundaries.
        boundaries: [18750, 37500, 50000]
        values: [0.1, 0.01, 0.001, 0.0001]
official/vision/beta/projects/example/example_config_tpu.yaml
0 → 100644
View file @
0650ea24
# TPU override config for the example experiment: bfloat16 mixed precision,
# large (4096) global batch, ImageNet-sized data.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 1001
    input_size: [128, 128, 3]
  train_data:
    input_path: 'imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 4096
    dtype: 'bfloat16'
  validation_data:
    input_path: 'imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 4096
    dtype: 'bfloat16'
    # Keep the final partial batch so every validation example is evaluated.
    drop_remainder: false
trainer:
  train_steps: 62400
  validation_steps: 13
  # Validate, summarize and checkpoint once every 312 steps.
  validation_interval: 312
  steps_per_loop: 312
  summary_interval: 312
  checkpoint_interval: 312
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'stepwise'
      stepwise:
        # Learning-rate schedule: step drops at these global-step boundaries.
        boundaries: [18750, 37500, 50000]
        values: [0.1, 0.01, 0.001, 0.0001]
official/vision/beta/projects/example/example_input.py
0 → 100644
View file @
0650ea24
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example classification decoder and parser.
This file defines the Decoder and Parser to load data. The example is shown on
loading standard tf.Example data but non-standard tf.Example or other data
format can be supported by implementing proper decoder and parser.
"""
from
typing
import
Mapping
,
List
,
Tuple
# Import libraries
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
preprocess_ops
# Per-channel mean and stddev used to normalize images, expressed in the
# [0, 255] pixel range (the fractional values are the common ImageNet
# normalization constants scaled by 255).
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self):
    """Initializes the decoder.

    The constructor defines the mapping between the field name and the value
    from an input tf.Example. For example, we define two fields for image bytes
    and labels. There is no limit on the number of fields to decode.
    """
    image_feature = tf.io.FixedLenFeature((), tf.string, default_value='')
    # Label defaults to -1 so a missing label is distinguishable from class 0.
    label_feature = tf.io.FixedLenFeature((), tf.int64, default_value=-1)
    self._keys_to_features = {
        'image/encoded': image_feature,
        'image/class/label': label_feature,
    }

  def decode(self,
             serialized_example: tf.train.Example) -> Mapping[str, tf.Tensor]:
    """Decodes a tf.Example to a dictionary.

    This function decodes a serialized tf.Example to a dictionary. The output
    will be consumed by `_parse_train_data` and `_parse_validation_data` in
    Parser.

    Args:
      serialized_example: A serialized tf.Example.

    Returns:
      A dictionary of field key name and decoded tensor mapping.
    """
    return tf.io.parse_single_example(serialized_example,
                                      self._keys_to_features)
class Parser(parser.Parser):
  """Parser to parse an image and its annotations.

  To define own Parser, client should override _parse_train_data and
  _parse_eval_data functions, where decoded tensors are parsed with optional
  pre-processing steps. The output from the two functions can be any structure
  like tuple, list or dictionary.
  """

  def __init__(self, output_size: List[int], num_classes: int):
    """Initializes parameters for parsing annotations in the dataset.

    This example only takes two arguments but one can freely add as many
    arguments as needed. For example, pre-processing and augmentations usually
    happen in Parser, and related parameters can be passed in by this
    constructor.

    Args:
      output_size: `Tensor` or `list` for [height, width] of output image.
      num_classes: `int`, number of classes.
    """
    # NOTE(review): `num_classes` was previously annotated as `float`, but it
    # is a class count and the matching config field (ExampleModel.num_classes)
    # declares `int` — annotation fixed for consistency.
    self._output_size = output_size
    self._num_classes = num_classes
    # Images are converted to this dtype at the end of parsing.
    self._dtype = tf.float32

  def _parse_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Decodes, resizes and normalizes one image; casts its label to int32."""
    label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)

    image_bytes = decoded_tensors['image/encoded']
    image = tf.io.decode_jpeg(image_bytes, channels=3)
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
    # Static shape info is lost by tf.image.resize; restore it for downstream
    # layers that require fully-defined shapes.
    image = tf.ensure_shape(image, self._output_size + [3])

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=MEAN_RGB, scale=STDDEV_RGB)
    image = tf.image.convert_image_dtype(image, self._dtype)

    return image, label

  def _parse_train_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)

  def _parse_eval_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for evaluation.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)
official/vision/beta/projects/example/example_model.py
0 → 100644
View file @
0650ea24
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A sample model implementation.
This is only a dummy example to showcase how a model is composed. It is usually
not needed to implement a model from scratch. Most SoTA models can be found and
directly used from `official/vision/beta/modeling` directory.
"""
from
typing
import
Any
,
Mapping
# Import libraries
import
tensorflow
as
tf
from
official.vision.beta.projects.example
import
example_config
as
example_cfg
@tf.keras.utils.register_keras_serializable(package='Vision')
class ExampleModel(tf.keras.Model):
  """An example model class.

  A model is a subclass of tf.keras.Model where layers are built in the
  constructor.
  """

  def __init__(
      self,
      num_classes: int,
      input_specs: tf.keras.layers.InputSpec = tf.keras.layers.InputSpec(
          shape=[None, None, None, 3]),
      **kwargs):
    """Initializes the example model.

    All layers are defined in the constructor, and config is recorded in the
    `_config_dict` object for serialization.

    Args:
      num_classes: The number of classes in classification task.
      input_specs: A `tf.keras.layers.InputSpec` spec of the input tensor.
      **kwargs: Additional keyword arguments to be passed.
    """
    inputs = tf.keras.Input(shape=input_specs.shape[1:], name=input_specs.name)

    # Three stride-2 3x3 conv stages (16 -> 32 -> 64 filters), followed by
    # global pooling and a two-layer classification head.
    net = inputs
    for num_filters in (16, 32, 64):
      net = tf.keras.layers.Conv2D(
          filters=num_filters,
          kernel_size=3,
          strides=2,
          padding='same',
          use_bias=False)(
              net)
    net = tf.keras.layers.GlobalAveragePooling2D()(net)
    net = tf.keras.layers.Dense(1024, activation='relu')(net)
    outputs = tf.keras.layers.Dense(num_classes)(net)

    super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    self._input_specs = input_specs
    self._config_dict = {
        'num_classes': num_classes,
        'input_specs': input_specs,
    }

  def get_config(self) -> Mapping[str, Any]:
    """Gets the config of this model."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Constructs an instance of this model from input config."""
    return cls(**config)
def build_example_model(input_specs: tf.keras.layers.InputSpec,
                        model_config: example_cfg.ExampleModel,
                        **kwargs) -> tf.keras.Model:
  """Builds and returns the example model.

  This function is the main entry point to build a model. Commonly, it builds a
  model by building a backbone, decoder and head. An example of building a
  classification model is at
  third_party/tensorflow_models/official/vision/beta/modeling/backbones/resnet.py.
  However, it is not mandatory for all models to have these three pieces
  exactly. Depending on the task, model can be as simple as the example model
  here or more complex, such as multi-head architecture.

  Args:
    input_specs: The specs of the input layer that defines input size.
    model_config: The config containing parameters to build a model.
    **kwargs: Additional keyword arguments to be passed.

  Returns:
    A tf.keras.Model object.
  """
  model = ExampleModel(
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      **kwargs)
  return model
official/vision/beta/projects/example/example_task.py
0 → 100644
View file @
0650ea24
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example task definition for image classification."""
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Sequence
,
Mapping
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling
import
tf_utils
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.projects.example
import
example_config
as
exp_cfg
from
official.vision.beta.projects.example
import
example_input
from
official.vision.beta.projects.example
import
example_model
@task_factory.register_task_cls(exp_cfg.ExampleTask)
class ExampleTask(base_task.Task):
  """Class of an example task.

  A task is a subclass of base_task.Task that defines model, input, loss, metric
  and one training and evaluation step, etc.
  """

  def build_model(self) -> tf.keras.Model:
    """Builds a model."""
    # Prepend a `None` batch dimension to the configured [H, W, C] input size.
    input_specs = tf.keras.layers.InputSpec(shape=[None] +
                                            self.task_config.model.input_size)

    model = example_model.build_example_model(
        input_specs=input_specs, model_config=self.task_config.model)
    return model

  def build_inputs(
      self,
      params: exp_cfg.ExampleDataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds input.

    The input from this function is a tf.data.Dataset that has gone through
    pre-processing steps, such as augmentation, batching, shuffling, etc.

    Args:
      params: The data config used to build the input pipeline.
      input_context: An optional InputContext used by input reader.

    Returns:
      A tf.data.Dataset object.
    """
    num_classes = self.task_config.model.num_classes
    input_size = self.task_config.model.input_size

    decoder = example_input.Decoder()
    # output_size takes only [height, width]; channels are fixed in the parser.
    parser = example_input.Parser(
        output_size=input_size[:2], num_classes=num_classes)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Builds losses for training and validation.

    Args:
      labels: Input groundtruth labels.
      model_outputs: Output of the model.
      aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.

    Returns:
      The total loss tensor.
    """
    total_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, model_outputs, from_logits=True)
    total_loss = tf_utils.safe_mean(total_loss)

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self,
                    training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
    """Gets streaming metrics for training/validation.

    This function builds and returns a list of metrics to compute during
    training and validation. The list contains objects of subclasses of
    tf.keras.metrics.Metric. Training and validation can have different metrics.

    Args:
      training: Whether the metric is for training or not.

    Returns:
      A list of tf.keras.metrics.Metric objects.
    """
    k = self.task_config.evaluation.top_k
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(
            k=k, name='top_{}_accuracy'.format(k))
    ]
    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None) -> Mapping[str, Any]:
    """Does forward and backward.

    This example assumes input is a tuple of (features, labels), which follows
    the output from data loader, i.e., Parser. The output from Parser is fed
    into train_step to perform one step forward and backward pass. Other data
    structure, such as dictionary, can also be used, as long as it is consistent
    between output from Parser and input used here.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf.keras.Model instance.
      optimizer: The optimizer for this training step.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # `self.loss` is the canonical loss key expected by the trainer loop.
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None) -> Mapping[str, Any]:
    """Runs validation step.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf.keras.Model instance.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    outputs = self.inference_step(features, model)
    # Cast to float32 for loss/metric computation under mixed precision.
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    return logs

  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model) -> Any:
    """Performs the forward step. It is used in validation_step."""
    return model(inputs, training=False)
official/vision/beta/projects/example/registry_imports.py
0 → 100644
View file @
0650ea24
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration.
Custom models, task, configs, etc need to be imported to registry so they can be
picked up by the trainer. They can be included in this file so you do not need
to handle each file separately.
"""
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.vision.beta.projects.example
import
example_config
from
official.vision.beta.projects.example
import
example_input
from
official.vision.beta.projects.example
import
example_model
from
official.vision.beta.projects.example
import
example_task
official/vision/beta/projects/example/train.py
0 → 100644
View file @
0650ea24
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision trainer.
All custom registry are imported from registry_imports. Here we use default
trainer so we directly call train.main. If you need to customize the trainer,
branch from `official/vision/beta/train.py` and make changes.
"""
from
absl
import
app
from
official.common
import
flags
as
tfm_flags
from
official.vision.beta
import
train
from
official.vision.beta.projects.example
import
registry_imports
# pylint: disable=unused-import
if __name__ == '__main__':
  # Register the standard TF Model Garden command-line flags, then hand
  # control to the default vision trainer's main().
  tfm_flags.define_flags()
  app.run(train.main)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment