Commit c8e6faf7 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 431756117
parent 13a5e4fb
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper utils for export library."""
from typing import List, Optional
import tensorflow as tf
# pylint: disable=g-long-lambda
def get_image_input_signatures(input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3):
"""Gets input signatures for an image.
Args:
input_type: A `str`, can be either tf_example, image_bytes, or image_tensor.
batch_size: `int` for batch size or None.
input_image_size: List[int] for the height and width of the input image.
num_channels: `int` for number of channels in the input image.
Returns:
tf.TensorSpec of the input tensor.
"""
if input_type == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[batch_size] + [None] * len(input_image_size) + [num_channels],
dtype=tf.uint8)
elif input_type in ['image_bytes', 'serve_examples', 'tf_example']:
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
elif input_type == 'tflite':
input_signature = tf.TensorSpec(
shape=[1] + input_image_size + [num_channels], dtype=tf.float32)
else:
    raise ValueError(f'Unrecognized `input_type`: {input_type!r}')
return input_signature
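# Editor's usage sketch (not part of the original change): illustrates the
# specs produced for two `input_type` values; the 224x224 RGB inputs and the
# helper name `_example_get_image_input_signatures` are assumptions.
def _example_get_image_input_signatures():
  spec = get_image_input_signatures(
      'image_tensor', batch_size=None, input_image_size=[224, 224])
  assert spec.shape.as_list() == [None, None, None, 3]  # dynamic batch/h/w
  spec = get_image_input_signatures(
      'tf_example', batch_size=8, input_image_size=[224, 224])
  assert spec.shape.as_list() == [8]  # one serialized proto per example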
def decode_image(encoded_image_bytes: str,
input_image_size: List[int],
num_channels: int = 3,) -> tf.Tensor:
"""Decodes an image bytes to an image tensor.
Use `tf.image.decode_image` to decode an image if input is expected to be 2D
image; otherwise use `tf.io.decode_raw` to convert the raw bytes to tensor
and reshape it to desire shape.
Args:
encoded_image_bytes: An encoded image string to be decoded.
input_image_size: List[int] for the desired input size. This will be used to
infer whether the image is 2d or 3d.
num_channels: `int` for number of image channels.
Returns:
A decoded image tensor.
"""
if len(input_image_size) == 2:
# Decode an image if 2D input is expected.
image_tensor = tf.image.decode_image(
encoded_image_bytes, channels=num_channels)
else:
# Convert raw bytes into a tensor and reshape it, if not 2D input.
image_tensor = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
image_tensor.set_shape([None] * len(input_image_size) + [num_channels])
return image_tensor
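# Editor's usage sketch (an assumption, not from the original commit):
# round-trips a PNG through `decode_image` for the 2D case.
def _example_decode_image():
  image = tf.zeros([32, 32, 3], dtype=tf.uint8)
  png_bytes = tf.io.encode_png(image)
  decoded = decode_image(png_bytes, input_image_size=[32, 32])
  assert decoded.numpy().shape == (32, 32, 3)  # concrete shape in eager mode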
def decode_image_tf_example(tf_example_string_tensor: tf.Tensor,
input_image_size: List[int],
num_channels: int = 3,
encoded_key: str = 'image/encoded'
) -> tf.Tensor:
"""Decodes a TF Example to an image tensor."""
keys_to_features = {
encoded_key: tf.io.FixedLenFeature((), tf.string, default_value=''),
}
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=keys_to_features)
image_tensor = decode_image(
parsed_tensors[encoded_key],
input_image_size=input_image_size,
num_channels=num_channels)
return image_tensor
def parse_image(
inputs, input_type: str, input_image_size: List[int], num_channels: int):
"""Parses image."""
if input_type in ['tf_example', 'serve_examples']:
decode_image_tf_example_fn = (
lambda x: decode_image_tf_example(x, input_image_size, num_channels))
image_tensor = tf.map_fn(
decode_image_tf_example_fn,
elems=inputs,
fn_output_signature=tf.TensorSpec(
shape=[None] * len(input_image_size) + [num_channels],
dtype=tf.uint8),
)
elif input_type == 'image_bytes':
decode_image_fn = lambda x: decode_image(x, input_image_size, num_channels)
image_tensor = tf.map_fn(
decode_image_fn, elems=inputs,
fn_output_signature=tf.TensorSpec(
shape=[None] * len(input_image_size) + [num_channels],
dtype=tf.uint8),)
else:
image_tensor = inputs
return image_tensor
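# Editor's usage sketch (an assumption): parses a batch of two encoded images
# through the 'image_bytes' branch; each element is decoded independently by
# tf.map_fn.
def _example_parse_image():
  image = tf.zeros([16, 16, 3], dtype=tf.uint8)
  encoded = tf.stack([tf.io.encode_png(image), tf.io.encode_png(image)])
  images = parse_image(
      encoded, input_type='image_bytes', input_image_size=[16, 16],
      num_channels=3)
  assert images.numpy().shape == (2, 16, 16, 3)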
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification input and model functions for serving/inference."""
import tensorflow as tf
from official.vision.modeling import factory
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
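# Editor's note: these are the standard ImageNet per-channel statistics
# rescaled from [0, 1] to the [0, 255] pixel range, e.g. the red-channel mean
# is 0.485 * 255 = 123.675.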
class ClassificationModule(export_base.ExportModule):
"""classification Module."""
def _build_model(self):
input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3])
return factory.build_classification_model(
input_specs=input_specs,
model_config=self.params.task.model,
l2_regularizer=None)
def _build_inputs(self, image):
"""Builds classification model inputs for serving."""
# Center crops and resizes image.
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize(
image, self._input_image_size, method=tf.image.ResizeMethod.BILINEAR)
image = tf.reshape(
image, [self._input_image_size[0], self._input_image_size[1], 3])
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
return image
  def serve(self, images):
    """Casts images to float and runs inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3].

    Returns:
      A dictionary holding classification output logits and probabilities.
    """
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if self._input_type != 'tflite':
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs,
elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
parallel_iterations=32))
logits = self.inference_step(images)
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Test for image classification export lib."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import exp_factory
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import image_classification
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self, input_type):
params = exp_factory.get_exp_config('resnet_imagenet')
params.task.model.backbone.resnet.model_id = 18
classification_module = image_classification.ClassificationModule(
params,
batch_size=1,
input_image_size=[224, 224],
input_type=input_type)
return classification_module
def _export_from_module(self, module, input_type, save_directory):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module,
save_directory,
signatures=signatures)
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return tf.zeros((1, 224, 224, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue()]
elif input_type == 'tf_example':
image_tensor = tf.zeros((224, 224, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
elif input_type == 'tflite':
return tf.zeros((1, 224, 224, 3), dtype=np.float32)
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
{'input_type': 'tflite'},
)
def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_classification_module(input_type)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
classification_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
if input_type != 'tflite':
processed_images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._build_inputs,
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32)))
else:
processed_images = images
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertTrue(hasattr(imported.model, 'test_trackable'))
self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Semantic segmentation input and model functions for serving/inference."""
import tensorflow as tf
from official.vision.modeling import factory
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class SegmentationModule(export_base.ExportModule):
"""Segmentation Module."""
def _build_model(self):
input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3])
return factory.build_segmentation_model(
input_specs=input_specs,
model_config=self.params.task.model,
l2_regularizer=None)
def _build_inputs(self, image):
"""Builds classification model inputs for serving."""
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
image, image_info = preprocess_ops.resize_and_crop_image(
image,
self._input_image_size,
padded_size=self._input_image_size,
aug_scale_min=1.0,
aug_scale_max=1.0)
return image, image_info
  def serve(self, images):
    """Casts images to float and runs inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3].

    Returns:
      A dictionary holding segmentation output logits and, when preprocessing
      runs, the corresponding `image_info`.
    """
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
image_info = None
if self._input_type != 'tflite':
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images_spec = tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32)
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
images, image_info = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs,
elems=images,
fn_output_signature=(images_spec, image_info_spec),
parallel_iterations=32))
outputs = self.inference_step(images)
outputs['logits'] = tf.image.resize(
outputs['logits'], self._input_image_size, method='bilinear')
if image_info is not None:
outputs.update({'image_info': image_info})
return outputs
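# Editor's note (based on the documented behavior of
# preprocess_ops.resize_and_crop_image): `image_info` is a [4, 2] tensor
# stacking [[original height, width], [desired height, width],
# [y scale, x scale], [y offset, x offset]], which callers can use to map
# logits back to the original image coordinates.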
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Test for semantic segmentation export lib."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import exp_factory
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import semantic_segmentation
class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_segmentation_module(self, input_type):
params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
segmentation_module = semantic_segmentation.SegmentationModule(
params,
batch_size=1,
input_image_size=[112, 112],
input_type=input_type)
return segmentation_module
def _export_from_module(self, module, input_type, save_directory):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, save_directory, signatures=signatures)
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return tf.zeros((1, 112, 112, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue()]
elif input_type == 'tf_example':
image_tensor = tf.zeros((112, 112, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
elif input_type == 'tflite':
return tf.zeros((1, 112, 112, 3), dtype=np.float32)
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
{'input_type': 'tflite'},
)
def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_segmentation_module(input_type)
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
segmentation_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
if input_type != 'tflite':
processed_images, _ = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._build_inputs,
elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
fn_output_signature=(tf.TensorSpec(
shape=[112, 112, 3], dtype=tf.float32),
tf.TensorSpec(
shape=[4, 2], dtype=tf.float32))))
else:
processed_images = images
expected_output = tf.image.resize(
module.model(processed_images, training=False)['logits'], [112, 112],
method='bilinear')
out = segmentation_fn(tf.constant(images))
self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification input and model functions for serving/inference."""
from typing import Mapping, Dict, Text
import tensorflow as tf
from official.vision.dataloaders import video_input
from official.vision.serving import export_base
from official.vision.tasks import video_classification
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class VideoClassificationModule(export_base.ExportModule):
"""Video classification Module."""
def _build_model(self):
input_params = self.params.task.train_data
self._num_frames = input_params.feature_shape[0]
self._stride = input_params.temporal_stride
self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1]
self._output_audio = input_params.output_audio
task = video_classification.VideoClassificationTask(self.params.task)
return task.build_model()
def _decode_tf_example(self, encoded_inputs: tf.Tensor):
sequence_description = {
# Each image is a string encoding JPEG.
video_input.IMAGE_KEY:
tf.io.FixedLenSequenceFeature((), tf.string),
}
if self._output_audio:
sequence_description[self._params.task.validation_data.audio_feature] = (
tf.io.VarLenFeature(dtype=tf.float32))
_, decoded_tensors = tf.io.parse_single_sequence_example(
encoded_inputs, {}, sequence_description)
for key, value in decoded_tensors.items():
if isinstance(value, tf.SparseTensor):
decoded_tensors[key] = tf.sparse.to_dense(value)
return decoded_tensors
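  # Editor's sketch (an assumption, mirroring the tfexample_utils helpers used
  # in the tests): builds a minimal tf.train.SequenceExample whose
  # 'image/encoded' feature list holds JPEG frames, i.e. the serialized input
  # that `_decode_tf_example` expects.
  @staticmethod
  def _example_make_sequence_example(num_frames: int = 2) -> bytes:
    frame = tf.image.encode_jpeg(tf.zeros([8, 8, 3], tf.uint8)).numpy()
    seq = tf.train.SequenceExample()
    image_list = seq.feature_lists.feature_list[video_input.IMAGE_KEY]
    for _ in range(num_frames):
      image_list.feature.add().bytes_list.value.append(frame)
    return seq.SerializeToString()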
def _preprocess_image(self, image):
image = video_input.process_image(
image=image,
is_training=False,
num_frames=self._num_frames,
stride=self._stride,
num_test_clips=1,
min_resize=self._min_resize,
crop_size=self._crop_size,
num_crops=1)
image = tf.cast(image, tf.float32) # Use config.
features = {'image': image}
return features
def _preprocess_audio(self, audio):
features = {}
audio = tf.cast(audio, dtype=tf.float32) # Use config.
audio = video_input.preprocess_ops_3d.sample_sequence(
audio, 20, random=False, stride=1)
audio = tf.ensure_shape(
audio, self._params.task.validation_data.audio_feature_shape)
features['audio'] = audio
return features
@tf.function
def inference_from_tf_example(
self, encoded_inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
with tf.device('cpu:0'):
if self._output_audio:
inputs = tf.map_fn(
self._decode_tf_example, (encoded_inputs),
fn_output_signature={
video_input.IMAGE_KEY: tf.string,
self._params.task.validation_data.audio_feature: tf.float32
})
return self.serve(inputs['image'], inputs['audio'])
else:
inputs = tf.map_fn(
self._decode_tf_example, (encoded_inputs),
fn_output_signature={
video_input.IMAGE_KEY: tf.string,
})
return self.serve(inputs[video_input.IMAGE_KEY], tf.zeros([1, 1]))
@tf.function
def inference_from_image_tensors(
self, input_frames: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(input_frames, tf.zeros([1, 1]))
@tf.function
def inference_from_image_audio_tensors(
self, input_frames: tf.Tensor,
input_audio: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(input_frames, input_audio)
@tf.function
def inference_from_image_bytes(self, inputs: tf.Tensor):
raise NotImplementedError(
        'Video classification does not support image bytes input.')
def serve(self, input_frames: tf.Tensor, input_audio: tf.Tensor):
"""Cast image to float and run inference.
Args:
input_frames: uint8 Tensor of shape [batch_size, None, None, 3]
input_audio: float32
Returns:
Tensor holding classification output logits.
"""
with tf.device('cpu:0'):
inputs = tf.map_fn(
self._preprocess_image, (input_frames),
fn_output_signature={
'image': tf.float32,
})
if self._output_audio:
inputs.update(
tf.map_fn(
self._preprocess_audio, (input_audio),
fn_output_signature={'audio': tf.float32}))
logits = self.inference_step(inputs)
if self.params.task.train_data.is_multilabel:
probs = tf.math.sigmoid(logits)
else:
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
"""Gets defined function signatures.
Args:
function_keys: A dictionary with keys as the function to create signature
for and values as the signature keys when returns.
Returns:
A dictionary with key as signature key and value as concrete functions
that can be used for tf.saved_model.save.
"""
signatures = {}
for key, def_name in function_keys.items():
if key == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size + [3],
dtype=tf.uint8,
name='INPUT_FRAMES')
signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature)
elif key == 'frames_audio':
input_signature = [
tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size + [3],
dtype=tf.uint8,
name='INPUT_FRAMES'),
tf.TensorSpec(
shape=[self._batch_size] +
self.params.task.train_data.audio_feature_shape,
dtype=tf.float32,
name='INPUT_AUDIO')
]
signatures[
def_name] = self.inference_from_image_audio_tensors.get_concrete_function(
input_signature)
elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
else:
      raise ValueError(f'Unrecognized `input_type`: {key!r}')
return signatures
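# Editor's usage sketch (an assumption, matching the tests below): export the
# module with a single image-tensor signature.
#
#   module = VideoClassificationModule(
#       params, batch_size=1, input_image_size=[8, 64, 64])
#   signatures = module.get_inference_signatures(
#       {'image_tensor': 'serving_default'})
#   tf.saved_model.save(module, '/tmp/video_model', signatures=signatures)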
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Test for video classification export lib."""
import os
import random
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.core import exp_factory
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.dataloaders import tfexample_utils
from official.vision.serving import video_classification
class VideoClassificationTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self):
params = exp_factory.get_exp_config('video_classification_ucf101')
params.task.train_data.feature_shape = (8, 64, 64, 3)
params.task.validation_data.feature_shape = (8, 64, 64, 3)
params.task.model.backbone.resnet_3d.model_id = 50
classification_module = video_classification.VideoClassificationModule(
params, batch_size=1, input_image_size=[8, 64, 64])
return classification_module
def _export_from_module(self, module, input_type, save_directory):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module, save_directory, signatures=signatures)
def _get_dummy_input(self, input_type, module=None):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
images = np.random.randint(
low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
return images, images
elif input_type == 'tf_example':
example = tfexample_utils.make_video_test_example(
image_shape=(64, 64, 3),
audio_shape=(20, 128),
label=random.randint(0, 100)).SerializeToString()
images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._decode_tf_example,
elems=tf.constant([example]),
fn_output_signature={
video_classification.video_input.IMAGE_KEY: tf.string,
}))
images = images[video_classification.video_input.IMAGE_KEY]
return [example], images
else:
      raise ValueError(f'Unsupported input_type: {input_type}')
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type):
tmp_dir = self.get_temp_dir()
module = self._get_classification_module()
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
classification_fn = imported.signatures['serving_default']
images, images_tensor = self._get_dummy_input(input_type, module)
processed_images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._preprocess_image,
elems=images_tensor,
fn_output_signature={
'image': tf.float32,
}))
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
# The imported model should contain any trackable attrs that the original
# model had.
self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tasks package definition."""
from official.vision.tasks import image_classification
from official.vision.tasks import maskrcnn
from official.vision.tasks import retinanet
from official.vision.tasks import semantic_segmentation
from official.vision.tasks import video_classification
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.configs import image_classification as exp_cfg
from official.vision.dataloaders import classification_input
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tfds_factory
from official.vision.modeling import factory
from official.vision.ops import augment
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(base_task.Task):
"""A task for image classification."""
def build_model(self):
"""Builds classification model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_classification_model(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
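  # Editor's worked example for the division above: tf.nn.l2_loss(w) computes
  # sum(w**2) / 2, while tf.keras.regularizers.l2(l) computes l * sum(w**2).
  # For w = [3.0, 4.0]: tf.nn.l2_loss(w) == 12.5, and l2(0.1 / 2.0)(w) ==
  # 0.05 * 25.0 == 1.25 == 0.1 * tf.nn.l2_loss(w), so passing
  # l2_weight_decay / 2.0 reproduces l2_weight_decay * tf.nn.l2_loss(w).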
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(model=model)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(
self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Builds classification input."""
num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size
image_field_key = self.task_config.train_data.image_field_key
label_field_key = self.task_config.train_data.label_field_key
is_multilabel = self.task_config.train_data.is_multilabel
if params.tfds_name:
decoder = tfds_factory.get_classification_decoder(params.tfds_name)
else:
decoder = classification_input.Decoder(
image_field_key=image_field_key, label_field_key=label_field_key,
is_multilabel=is_multilabel)
parser = classification_input.Parser(
output_size=input_size[:2],
num_classes=num_classes,
image_field_key=image_field_key,
label_field_key=label_field_key,
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type,
color_jitter=params.color_jitter,
random_erasing=params.random_erasing,
is_multilabel=is_multilabel,
dtype=params.dtype)
postprocess_fn = None
if params.mixup_and_cutmix:
postprocess_fn = augment.MixupAndCutmix(
mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
prob=params.mixup_and_cutmix.prob,
label_smoothing=params.mixup_and_cutmix.label_smoothing,
num_classes=num_classes)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training),
postprocess_fn=postprocess_fn)
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
labels: tf.Tensor,
model_outputs: tf.Tensor,
aux_losses: Optional[Any] = None) -> tf.Tensor:
"""Builds sparse categorical cross entropy loss.
Args:
labels: Input groundtruth labels.
model_outputs: Output logits of the classifier.
aux_losses: The auxiliarly loss tensors, i.e. `losses` in tf.keras.Model.
Returns:
The total loss tensor.
"""
losses_config = self.task_config.losses
is_multilabel = self.task_config.train_data.is_multilabel
if not is_multilabel:
if losses_config.one_hot:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
elif losses_config.soft_labels:
total_loss = tf.nn.softmax_cross_entropy_with_logits(
labels, model_outputs)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
else:
# Multi-label weighted binary cross entropy loss.
total_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=model_outputs)
total_loss = tf.reduce_sum(total_loss, axis=-1)
total_loss = tf_utils.safe_mean(total_loss)
if aux_losses:
total_loss += tf.add_n(aux_losses)
total_loss = losses_config.loss_weight * total_loss
return total_loss
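  # Editor's sketch (an assumption): the default sparse, non-multilabel branch
  # above reduces to plain sparse categorical cross entropy on logits.
  @staticmethod
  def _example_sparse_ce():
    labels = tf.constant([1, 0])
    logits = tf.constant([[0.0, 2.0, 0.0], [3.0, 0.0, 0.0]])
    per_example = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)
    return tf_utils.safe_mean(per_example)  # scalar mean over the batch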
def build_metrics(self,
training: bool = True) -> List[tf.keras.metrics.Metric]:
"""Gets streaming metrics for training/validation."""
is_multilabel = self.task_config.train_data.is_multilabel
if not is_multilabel:
k = self.task_config.evaluation.top_k
if (self.task_config.losses.one_hot or
self.task_config.losses.soft_labels):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else:
metrics = []
      # These metrics destabilize training if computed during training; the
      # jobs fail due to OOM.
# TODO(arashwan): Investigate adding following metric to train.
if not training:
metrics = [
tf.keras.metrics.AUC(
name='globalPR-AUC',
curve='PR',
multi_label=False,
from_logits=True),
tf.keras.metrics.AUC(
name='meanPR-AUC',
curve='PR',
multi_label=True,
num_labels=self.task_config.model.num_classes,
from_logits=True),
]
return metrics
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: A tuple of of input tensors of (features, labels).
model: A tf.keras.Model instance.
optimizer: The optimizer for this training step.
metrics: A nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
is_multilabel = self.task_config.train_data.is_multilabel
if self.task_config.losses.one_hot and not is_multilabel:
labels = tf.one_hot(labels, self.task_config.model.num_classes)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
      # Casting the output layer as float32 is necessary when mixed_precision
      # is mixed_float16 or mixed_bfloat16, to ensure the output is cast to
      # float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs,
labels=labels,
aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
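  # Editor's note on the scaling above: with N replicas, the gradient
  # allreduce sums per-replica gradients, so dividing each replica's loss by N
  # yields gradients of the global mean loss; e.g. with N == 8 and a
  # per-replica loss of 4.0, scaled_loss == 0.5 and the allreduced gradient
  # matches that of the 8-replica average.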
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Runs validatation step.
Args:
inputs: A tuple of of input tensors of (features, labels).
model: A tf.keras.Model instance.
metrics: A nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
one_hot = self.task_config.losses.one_hot
soft_labels = self.task_config.losses.soft_labels
is_multilabel = self.task_config.train_data.is_multilabel
if (one_hot or soft_labels) and not is_multilabel:
labels = tf.one_hot(labels, self.task_config.model.num_classes)
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(
model_outputs=outputs,
labels=labels,
aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
return model(inputs, training=False)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaskRCNN task definition."""
import os
from typing import Any, Optional, List, Tuple, Mapping
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision.configs import maskrcnn as exp_cfg
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import maskrcnn_input
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.evaluation import coco_utils
from official.vision.losses import maskrcnn_losses
from official.vision.modeling import factory
def zero_out_disallowed_class_ids(batch_class_ids: tf.Tensor,
allowed_class_ids: List[int]):
"""Zero out IDs of classes not in allowed_class_ids.
Args:
batch_class_ids: A [batch_size, num_instances] int tensor of input
class IDs.
allowed_class_ids: A python list of class IDs which we want to allow.
Returns:
filtered_class_ids: A [batch_size, num_instances] int tensor with any
class ID not in allowed_class_ids set to 0.
"""
allowed_class_ids = tf.constant(allowed_class_ids,
dtype=batch_class_ids.dtype)
match_ids = (batch_class_ids[:, :, tf.newaxis] ==
allowed_class_ids[tf.newaxis, tf.newaxis, :])
match_ids = tf.reduce_any(match_ids, axis=2)
return tf.where(match_ids, batch_class_ids, tf.zeros_like(batch_class_ids))
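# Editor's usage sketch (not part of the original change): class IDs outside
# the allowed set are zeroed out, and IDs equal to 0 are then ignored by the
# mask loss.
def _example_zero_out_disallowed_class_ids():
  class_ids = tf.constant([[1, 2, 3], [4, 5, 6]])
  filtered = zero_out_disallowed_class_ids(class_ids, allowed_class_ids=[2, 5])
  assert filtered.numpy().tolist() == [[0, 2, 0], [0, 5, 0]]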
@task_factory.register_task_cls(exp_cfg.MaskRCNNTask)
class MaskRCNNTask(base_task.Task):
"""A single-replica view of training procedure.
Mask R-CNN task provides artifacts for training/evalution procedures,
including loading/iterating over Datasets, initializing the model, calculating
the loss, post-processing, and customized metrics with reduction.
"""
def build_model(self):
"""Build Mask R-CNN model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_maskrcnn(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
include_mask=self._task_config.model.include_mask,
regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
include_mask=self._task_config.model.include_mask,
regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
parser = maskrcnn_input.Parser(
output_size=self.task_config.model.input_size[:2],
min_level=self.task_config.model.min_level,
max_level=self.task_config.model.max_level,
num_scales=self.task_config.model.anchor.num_scales,
aspect_ratios=self.task_config.model.anchor.aspect_ratios,
anchor_size=self.task_config.model.anchor.anchor_size,
dtype=params.dtype,
rpn_match_threshold=params.parser.rpn_match_threshold,
rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold,
rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im,
rpn_fg_fraction=params.parser.rpn_fg_fraction,
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
skip_crowd_during_training=params.parser.skip_crowd_during_training,
max_num_instances=params.parser.max_num_instances,
include_mask=self._task_config.model.include_mask,
mask_crop_size=params.parser.mask_crop_size)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build Mask R-CNN losses."""
params = self.task_config
cascade_ious = params.model.roi_sampler.cascade_iou_thresholds
rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
tf.shape(outputs['box_outputs'])[1])
rpn_box_loss_fn = maskrcnn_losses.RpnBoxLoss(
params.losses.rpn_huber_loss_delta)
rpn_score_loss = tf.reduce_mean(
rpn_score_loss_fn(
outputs['rpn_scores'], labels['rpn_score_targets']))
rpn_box_loss = tf.reduce_mean(
rpn_box_loss_fn(
outputs['rpn_boxes'], labels['rpn_box_targets']))
frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
params.losses.frcnn_huber_loss_delta,
params.model.detection_head.class_agnostic_bbox_pred)
# Final cls/box losses are computed as an average of all detection heads.
frcnn_cls_loss = 0.0
frcnn_box_loss = 0.0
num_det_heads = 1 if cascade_ious is None else 1 + len(cascade_ious)
for cas_num in range(num_det_heads):
frcnn_cls_loss_i = tf.reduce_mean(
frcnn_cls_loss_fn(
outputs['class_outputs_{}'
.format(cas_num) if cas_num else 'class_outputs'],
outputs['class_targets_{}'
.format(cas_num) if cas_num else 'class_targets']))
frcnn_box_loss_i = tf.reduce_mean(
frcnn_box_loss_fn(
outputs['box_outputs_{}'.format(cas_num
) if cas_num else 'box_outputs'],
outputs['class_targets_{}'
.format(cas_num) if cas_num else 'class_targets'],
outputs['box_targets_{}'.format(cas_num
) if cas_num else 'box_targets']))
frcnn_cls_loss += frcnn_cls_loss_i
frcnn_box_loss += frcnn_box_loss_i
frcnn_cls_loss /= num_det_heads
frcnn_box_loss /= num_det_heads
if params.model.include_mask:
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
mask_class_targets = outputs['mask_class_targets']
if self._task_config.allowed_mask_class_ids is not None:
# Classes with ID=0 are ignored by mask_loss_fn in loss computation.
mask_class_targets = zero_out_disallowed_class_ids(
mask_class_targets, self._task_config.allowed_mask_class_ids)
mask_loss = tf.reduce_mean(
mask_loss_fn(
outputs['mask_outputs'],
outputs['mask_targets'],
mask_class_targets))
else:
mask_loss = 0.0
model_loss = (
params.losses.rpn_score_weight * rpn_score_loss +
params.losses.rpn_box_weight * rpn_box_loss +
params.losses.frcnn_class_weight * frcnn_cls_loss +
params.losses.frcnn_box_weight * frcnn_box_loss +
params.losses.mask_weight * mask_loss)
total_loss = model_loss
if aux_losses:
reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss
total_loss = params.losses.loss_weight * total_loss
losses = {
'total_loss': total_loss,
'rpn_score_loss': rpn_score_loss,
'rpn_box_loss': rpn_box_loss,
'frcnn_cls_loss': frcnn_cls_loss,
'frcnn_box_loss': frcnn_box_loss,
'mask_loss': mask_loss,
'model_loss': model_loss,
}
return losses
def _build_coco_metrics(self):
"""Build COCO metrics evaluator."""
if (not self._task_config.model.include_mask
) or self._task_config.annotation_file:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self._task_config.annotation_file,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
else:
# Builds COCO-style annotation file if include_mask is True, and
# annotation_file isn't provided.
annotation_path = os.path.join(self._logging_dir, 'annotation.json')
if tf.io.gfile.exists(annotation_path):
logging.info(
'annotation.json file exists, skipping creating the annotation'
' file.')
else:
        if self._task_config.validation_data.num_examples <= 0:
          raise ValueError('validation_data.num_examples needs to be > 0.')
        if not self._task_config.validation_data.input_path:
          raise ValueError('Cannot create annotation file for tfds.')
logging.info(
'Creating coco-style annotation file: %s', annotation_path)
coco_utils.scan_and_generator_annotation_file(
self._task_config.validation_data.input_path,
self._task_config.validation_data.file_type,
self._task_config.validation_data.num_examples,
self.task_config.model.include_mask, annotation_path,
regenerate_source_id=self._task_config.validation_data.decoder
.simple_decoder.regenerate_source_id)
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_path,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
def build_metrics(self, training: bool = True):
"""Build detection metrics."""
metrics = []
if training:
metric_names = [
'total_loss',
'rpn_score_loss',
'rpn_box_loss',
'frcnn_cls_loss',
'frcnn_box_loss',
'mask_loss',
'model_loss'
]
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else:
if self._task_config.use_coco_metrics:
self._build_coco_metrics()
if self._task_config.use_wod_metrics:
# To use Waymo open dataset metrics, please install one of the pip
# package `waymo-open-dataset-tf-*` from
# https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md#use-pre-compiled-pippip3-packages-for-linux
# Note that the package is built with specific tensorflow version and
# will produce error if it does not match the tf version that is
# currently used.
try:
from official.vision.evaluation import wod_detection_evaluator # pylint: disable=g-import-not-at-top
except ModuleNotFoundError:
logging.error('waymo-open-dataset should be installed to enable Waymo'
' evaluator.')
raise
self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()
return metrics
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
images, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(
images,
image_shape=labels['image_info'][:, 1, :],
anchor_boxes=labels['anchor_boxes'],
gt_boxes=labels['gt_boxes'],
gt_classes=labels['gt_classes'],
gt_masks=(labels['gt_masks'] if self.task_config.model.include_mask
else None),
training=True)
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
losses = self.build_losses(
outputs=outputs, labels=labels, aux_losses=model.losses)
scaled_loss = losses['total_loss'] / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient when LossScaleOptimizer is used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: losses['total_loss']}
if metrics:
for m in metrics:
m.update_state(losses[m.name])
return logs
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
images, labels = inputs
outputs = model(
images,
anchor_boxes=labels['anchor_boxes'],
image_shape=labels['image_info'][:, 1, :],
training=False)
logs = {self.loss: 0}
if self._task_config.use_coco_metrics:
coco_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
if self.task_config.model.include_mask:
coco_model_outputs.update({
'detection_masks': outputs['detection_masks'],
})
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
if self.task_config.use_wod_metrics:
wod_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update(
{self.wod_metric.name: (labels['groundtruths'], wod_model_outputs)})
return logs
def aggregate_logs(self, state=None, step_outputs=None):
if self._task_config.use_coco_metrics:
if state is None:
self.coco_metric.reset_states()
self.coco_metric.update_state(
step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
if self._task_config.use_wod_metrics:
if state is None:
self.wod_metric.reset_states()
self.wod_metric.update_state(
step_outputs[self.wod_metric.name][0],
step_outputs[self.wod_metric.name][1])
if state is None:
# Create an arbitrary state to indicate it's not the first step in the
# following calls to this function.
state = True
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
logs = {}
if self._task_config.use_coco_metrics:
logs.update(self.coco_metric.result())
if self._task_config.use_wod_metrics:
logs.update(self.wod_metric.result())
return logs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet task definition."""
from typing import Any, List, Mapping, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision.configs import retinanet as exp_cfg
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import retinanet_input
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.losses import focal_loss
from official.vision.losses import loss_utils
from official.vision.modeling import factory
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(base_task.Task):
"""A single-replica view of training procedure.
RetinaNet task provides artifacts for training/evalution procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss,
post-processing, and customized metrics with reduction.
"""
def build_model(self):
"""Build RetinaNet model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_retinanet(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = retinanet_input.Parser(
output_size=self.task_config.model.input_size[:2],
min_level=self.task_config.model.min_level,
max_level=self.task_config.model.max_level,
num_scales=self.task_config.model.anchor.num_scales,
aspect_ratios=self.task_config.model.anchor.aspect_ratios,
anchor_size=self.task_config.model.anchor.anchor_size,
dtype=params.dtype,
match_threshold=params.parser.match_threshold,
unmatched_threshold=params.parser.unmatched_threshold,
aug_type=params.parser.aug_type,
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
skip_crowd_during_training=params.parser.skip_crowd_during_training,
max_num_instances=params.parser.max_num_instances)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_attribute_loss(self,
attribute_heads: List[exp_cfg.AttributeHead],
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
box_sample_weight: tf.Tensor) -> float:
"""Computes attribute loss.
Args:
attribute_heads: a list of attribute head configs.
outputs: RetinaNet model outputs.
labels: RetinaNet labels.
box_sample_weight: normalized bounding box sample weights.
Returns:
Attribute loss of all attribute heads.
"""
attribute_loss = 0.0
for head in attribute_heads:
if head.name not in labels['attribute_targets']:
raise ValueError(f'Attribute {head.name} not found in label targets.')
if head.name not in outputs['attribute_outputs']:
raise ValueError(f'Attribute {head.name} not found in model outputs.')
y_true_att = loss_utils.multi_level_flatten(
labels['attribute_targets'][head.name], last_dim=head.size)
y_pred_att = loss_utils.multi_level_flatten(
outputs['attribute_outputs'][head.name], last_dim=head.size)
if head.type == 'regression':
att_loss_fn = tf.keras.losses.Huber(
1.0, reduction=tf.keras.losses.Reduction.SUM)
att_loss = att_loss_fn(
y_true=y_true_att,
y_pred=y_pred_att,
sample_weight=box_sample_weight)
else:
raise ValueError(f'Attribute type {head.type} not supported.')
attribute_loss += att_loss
return attribute_loss
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build RetinaNet losses."""
params = self.task_config
attribute_heads = self.task_config.model.head.attribute_heads
cls_loss_fn = focal_loss.FocalLoss(
alpha=params.losses.focal_loss_alpha,
gamma=params.losses.focal_loss_gamma,
reduction=tf.keras.losses.Reduction.SUM)
box_loss_fn = tf.keras.losses.Huber(
params.losses.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
# Sums all positives in a batch for normalization and avoids zero
# num_positives_sum, which would lead to inf loss during training.
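# E.g. an all-background batch gives num_positives = 0 + 1 = 1 below, so the
# normalized sample weights stay finite.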
cls_sample_weight = labels['cls_weights']
box_sample_weight = labels['box_weights']
num_positives = tf.reduce_sum(box_sample_weight) + 1.0
cls_sample_weight = cls_sample_weight / num_positives
box_sample_weight = box_sample_weight / num_positives
y_true_cls = loss_utils.multi_level_flatten(
labels['cls_targets'], last_dim=None)
y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes)
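# y_true_cls now has shape [batch, num_total_anchors, num_classes], matching
# the flattened y_pred_cls below.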
y_pred_cls = loss_utils.multi_level_flatten(
outputs['cls_outputs'], last_dim=params.model.num_classes)
y_true_box = loss_utils.multi_level_flatten(
labels['box_targets'], last_dim=4)
y_pred_box = loss_utils.multi_level_flatten(
outputs['box_outputs'], last_dim=4)
cls_loss = cls_loss_fn(
y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight)
box_loss = box_loss_fn(
y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight)
model_loss = cls_loss + params.losses.box_loss_weight * box_loss
if attribute_heads:
model_loss += self.build_attribute_loss(attribute_heads, outputs, labels,
box_sample_weight)
total_loss = model_loss
if aux_losses:
reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss
total_loss = params.losses.loss_weight * total_loss
return total_loss, cls_loss, box_loss, model_loss
def build_metrics(self, training: bool = True):
"""Build detection metrics."""
metrics = []
metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
if not training:
if self.task_config.validation_data.tfds_name and self.task_config.annotation_file:
raise ValueError(
"Can't evaluate using annotation file when TFDS is used.")
if self._task_config.use_coco_metrics:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self.task_config.annotation_file,
include_mask=False,
per_category_metrics=self.task_config.per_category_metrics)
if self._task_config.use_wod_metrics:
# To use Waymo open dataset metrics, please install one of the pip
# package `waymo-open-dataset-tf-*` from
# https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md#use-pre-compiled-pippip3-packages-for-linux
# Note that the package is built with specific tensorflow version and
# will produce error if it does not match the tf version that is
# currently used.
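# For example (the version suffix below is illustrative and must match the
# installed TF version): `pip install waymo-open-dataset-tf-2-6-0`.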
try:
from official.vision.evaluation import wod_detection_evaluator # pylint: disable=g-import-not-at-top
except ModuleNotFoundError:
logging.error('waymo-open-dataset should be installed to enable Waymo'
' evaluator.')
raise
self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()
return metrics
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a pair of (features, labels) input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss, cls_loss, box_loss, model_loss = self.build_losses(
outputs=outputs, labels=labels, aux_losses=model.losses)
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient when LossScaleOptimizer is used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
all_losses = {
'total_loss': loss,
'cls_loss': cls_loss,
'box_loss': box_loss,
'model_loss': model_loss,
}
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
logs.update({m.name: m.result()})
return logs
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
inputs: a pair of (features, labels) input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
outputs = model(features, anchor_boxes=labels['anchor_boxes'],
image_shape=labels['image_info'][:, 1, :],
training=False)
loss, cls_loss, box_loss, model_loss = self.build_losses(
outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
all_losses = {
'total_loss': loss,
'cls_loss': cls_loss,
'box_loss': box_loss,
'model_loss': model_loss,
}
if self._task_config.use_coco_metrics:
coco_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
if self.task_config.use_wod_metrics:
wod_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update(
{self.wod_metric.name: (labels['groundtruths'], wod_model_outputs)})
if metrics:
for m in metrics:
m.update_state(all_losses[m.name])
logs.update({m.name: m.result()})
return logs
def aggregate_logs(self, state=None, step_outputs=None):
if self._task_config.use_coco_metrics:
if state is None:
self.coco_metric.reset_states()
self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
if self._task_config.use_wod_metrics:
if state is None:
self.wod_metric.reset_states()
self.wod_metric.update_state(step_outputs[self.wod_metric.name][0],
step_outputs[self.wod_metric.name][1])
if state is None:
# Create an arbitrary state to indicate it's not the first step in the
# following calls to this function.
state = True
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
logs = {}
if self._task_config.use_coco_metrics:
logs.update(self.coco_metric.result())
if self._task_config.use_wod_metrics:
logs.update(self.wod_metric.result())
return logs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image segmentation task definition."""
from typing import Any, Optional, List, Tuple, Mapping, Union
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision.configs import semantic_segmentation as exp_cfg
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import segmentation_input
from official.vision.dataloaders import tfds_factory
from official.vision.evaluation import segmentation_metrics
from official.vision.losses import segmentation_losses
from official.vision.modeling import factory
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(base_task.Task):
"""A task for semantic segmentation."""
def build_model(self):
"""Builds segmentation model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_segmentation_model(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input."""
ignore_label = self.task_config.losses.ignore_label
if params.tfds_name:
decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
else:
decoder = segmentation_input.Decoder()
parser = segmentation_input.Parser(
output_size=params.output_size,
crop_size=params.crop_size,
ignore_label=ignore_label,
resize_eval_groundtruth=params.resize_eval_groundtruth,
groundtruth_padded_size=params.groundtruth_padded_size,
aug_scale_min=params.aug_scale_min,
aug_scale_max=params.aug_scale_max,
aug_rand_hflip=params.aug_rand_hflip,
preserve_aspect_ratio=params.preserve_aspect_ratio,
dtype=params.dtype)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
labels: Mapping[str, tf.Tensor],
model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
aux_losses: Optional[Any] = None):
"""Segmentation loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
loss_params = self._task_config.losses
segmentation_loss_fn = segmentation_losses.SegmentationLoss(
loss_params.label_smoothing,
loss_params.class_weights,
loss_params.ignore_label,
use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
top_k_percent_pixels=loss_params.top_k_percent_pixels)
total_loss = segmentation_loss_fn(model_outputs['logits'], labels['masks'])
if 'mask_scores' in model_outputs:
mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
loss_params.ignore_label)
total_loss += mask_scoring_loss_fn(
model_outputs['mask_scores'],
model_outputs['logits'],
labels['masks'])
if aux_losses:
total_loss += tf.add_n(aux_losses)
total_loss = loss_params.loss_weight * total_loss
return total_loss
def process_metrics(self, metrics, labels, model_outputs, **kwargs):
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
**kwargs: other args.
"""
for metric in metrics:
if metric.name == 'mask_scores_mse':
actual_mask_scores = segmentation_losses.get_actual_mask_scores(
model_outputs['logits'], labels['masks'],
self.task_config.losses.ignore_label)
metric.update_state(actual_mask_scores, model_outputs['mask_scores'])
else:
metric.update_state(labels, model_outputs['logits'])
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
metrics = []
if training and self.task_config.evaluation.report_train_mean_iou:
metrics.append(segmentation_metrics.MeanIoU(
name='mean_iou',
num_classes=self.task_config.model.num_classes,
rescale_predictions=False,
dtype=tf.float32))
if self.task_config.model.get('mask_scoring_head'):
metrics.append(
tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
else:
self.iou_metric = segmentation_metrics.PerClassIoU(
name='per_class_iou',
num_classes=self.task_config.model.num_classes,
rescale_predictions=not self.task_config.validation_data
.resize_eval_groundtruth,
dtype=tf.float32)
if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.get('mask_scoring_head'): # pylint: disable=line-too-long
# Mask scores metric can only be computed if labels are scaled to match
# predicted mask scores.
metrics.append(
tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
# Update state on CPU if TPUStrategy due to dynamic resizing.
self._process_iou_metric_on_cpu = isinstance(
tf.distribute.get_strategy(), tf.distribute.TPUStrategy)
return metrics
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a pair of (features, labels) input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.train_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features = strategy.experimental_split_to_logical_devices(
features, input_partition_dims)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
if isinstance(outputs, tf.Tensor):
outputs = {'logits': outputs}
# Casting the output layer to float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure the output is cast to float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
return logs
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
inputs: a pair of (features, labels) input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.eval_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features = strategy.experimental_split_to_logical_devices(
features, input_partition_dims)
outputs = self.inference_step(features, model)
if isinstance(outputs, tf.Tensor):
outputs = {'logits': outputs}
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.validation_data.resize_eval_groundtruth:
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
else:
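# Groundtruth masks keep their original per-image sizes here, so the loss
# is skipped during evaluation.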
loss = 0
logs = {self.loss: loss}
if self._process_iou_metric_on_cpu:
logs.update({self.iou_metric.name: (labels, outputs['logits'])})
else:
self.iou_metric.update_state(labels, outputs['logits'])
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
return logs
def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
return model(inputs, training=False)
def aggregate_logs(self, state=None, step_outputs=None):
if state is None:
self.iou_metric.reset_states()
state = self.iou_metric
if self._process_iou_metric_on_cpu:
self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
step_outputs[self.iou_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
result = {}
ious = self.iou_metric.result()
# TODO(arashwan): support loading class name from a label map file.
if self.task_config.evaluation.report_per_class_iou:
for i, value in enumerate(ious.numpy()):
result.update({'iou/{}'.format(i): value})
# Computes mean IoU
result.update({'mean_iou': tf.reduce_mean(ious).numpy()})
return result
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import video_input
from official.vision.modeling import factory_3d
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
class VideoClassificationTask(base_task.Task):
"""A task for video classification."""
def _get_num_classes(self):
"""Gets the number of classes."""
return self.task_config.train_data.num_classes
def _get_feature_shape(self):
"""Get the common feature shape for train and eval."""
return [
d1 if d1 == d2 else None
for d1, d2 in zip(self.task_config.train_data.feature_shape,
self.task_config.validation_data.feature_shape)
]
def _get_num_test_views(self):
"""Gets number of views for test."""
num_test_clips = self.task_config.validation_data.num_test_clips
num_test_crops = self.task_config.validation_data.num_test_crops
num_test_views = num_test_clips * num_test_crops
return num_test_views
def _is_multilabel(self):
"""If the label is multi-labels."""
return self.task_config.train_data.is_multilabel
def build_model(self):
"""Builds video classification model."""
common_input_shape = self._get_feature_shape()
input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
logging.info('Build model input %r', common_input_shape)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory_3d.build_model(
self.task_config.model.model_type,
input_specs=input_specs,
model_config=self.task_config.model,
num_classes=self._get_num_classes(),
l2_regularizer=l2_regularizer)
return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def _get_dataset_fn(self, params):
if params.file_type == 'tfrecord':
return tf.data.TFRecordDataset
else:
raise ValueError('Unknown input file type {!r}'.format(params.file_type))
def _get_decoder_fn(self, params):
if params.tfds_name:
decoder = video_input.VideoTfdsDecoder(
image_key=params.image_field_key, label_key=params.label_field_key)
else:
decoder = video_input.Decoder(
image_key=params.image_field_key, label_key=params.label_field_key)
if self.task_config.train_data.output_audio:
assert self.task_config.train_data.audio_feature, 'audio feature is empty'
decoder.add_feature(self.task_config.train_data.audio_feature,
tf.io.VarLenFeature(dtype=tf.float32))
return decoder.decode
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input."""
parser = video_input.Parser(
input_params=params,
image_key=params.image_field_key,
label_key=params.label_field_key)
postprocess_fn = video_input.PostBatchProcessor(params)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=self._get_dataset_fn(params),
decoder_fn=self._get_decoder_fn(params),
parser_fn=parser.parse_fn(params.is_training),
postprocess_fn=postprocess_fn)
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
labels: Any,
model_outputs: Any,
aux_losses: Optional[Any] = None):
"""Sparse categorical cross entropy loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
all_losses = {}
losses_config = self.task_config.losses
total_loss = None
if self._is_multilabel():
entropy = -tf.reduce_mean(
tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
total_loss = tf.keras.losses.binary_crossentropy(
labels, model_outputs, from_logits=False)
all_losses.update({
'class_loss': total_loss,
'entropy': entropy,
})
else:
if losses_config.one_hot:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=False,
label_smoothing=losses_config.label_smoothing)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=False)
total_loss = tf_utils.safe_mean(total_loss)
all_losses.update({
'class_loss': total_loss,
})
if aux_losses:
all_losses.update({
'reg_loss': aux_losses,
})
total_loss += tf.add_n(aux_losses)
all_losses[self.loss] = total_loss
return all_losses
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
if self.task_config.losses.one_hot:
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
]
if self._is_multilabel():
metrics.append(
tf.keras.metrics.AUC(
curve='ROC', multi_label=self._is_multilabel(), name='ROC-AUC'))
metrics.append(
tf.keras.metrics.RecallAtPrecision(
0.95, name='RecallAtPrecision95'))
metrics.append(
tf.keras.metrics.AUC(
curve='PR', multi_label=self._is_multilabel(), name='PR-AUC'))
if self.task_config.metrics.use_per_class_recall:
for i in range(self._get_num_classes()):
metrics.append(
tf.keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=1, name='top_1_accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy')
]
return metrics
def process_metrics(self, metrics: List[Any], labels: Any,
model_outputs: Any):
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
for metric in metrics:
metric.update_state(labels, model_outputs)
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a pair of (features, labels) input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.train_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features['image'] = strategy.experimental_split_to_logical_devices(
features['image'], input_partition_dims)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Casting the output layer to float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure the output is cast to float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
if self._is_multilabel():
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
all_losses = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
loss = all_losses[self.loss]
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = all_losses
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
inputs: a pair of (features, labels) input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.eval_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features['image'] = strategy.experimental_split_to_logical_devices(
features['image'], input_partition_dims)
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
logs = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
outputs = model(features, training=False)
if self._is_multilabel():
outputs = tf.math.sigmoid(outputs)
else:
outputs = tf.math.softmax(outputs)
num_test_views = self._get_num_test_views()
if num_test_views > 1:
# Averages output probabilities across multiple views.
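# E.g. with num_test_clips=4 and num_test_crops=3, logits of shape
# [batch * 12, num_classes] are reshaped to [batch, 12, num_classes] and
# averaged over the 12 views.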
outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
outputs = tf.reduce_mean(outputs, axis=1)
return outputs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.vision import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
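# Example invocation (experiment name and paths are illustrative):
# python3 train.py --experiment=retinanet_resnetfpn_coco \
#   --mode=train_and_eval --model_dir=/tmp/retinanet \
#   --config_file=path/to/overrides.yaml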
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver with spatial partitioning."""
from typing import Sequence
from absl import app
from absl import flags
import gin
import numpy as np
import tensorflow as tf
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision import registry_imports # pylint: disable=unused-import
FLAGS = flags.FLAGS
def get_computation_shape_for_model_parallelism(
input_partition_dims: Sequence[int]) -> Sequence[int]:
"""Returns computation shape to be used for TPUStrategy spatial partition.
Args:
input_partition_dims: The number of partitions along each dimension.
Returns:
A list of integers specifying the computation shape.
Raises:
ValueError: If the number of logical devices is not supported.
"""
num_logical_devices = np.prod(input_partition_dims)
if num_logical_devices == 1:
return [1, 1, 1, 1]
elif num_logical_devices == 2:
return [1, 1, 1, 2]
elif num_logical_devices == 4:
return [1, 2, 1, 2]
elif num_logical_devices == 8:
return [2, 2, 1, 2]
elif num_logical_devices == 16:
return [4, 2, 1, 2]
else:
raise ValueError(
'The number of logical devices %d is not supported. Supported numbers '
'are 1, 2, 4, 8, 16' % num_logical_devices)
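# For example, partitioning each image along height and width into two pieces,
# input_partition_dims = [1, 2, 2, 1], needs four logical devices per replica:
# get_computation_shape_for_model_parallelism([1, 2, 2, 1]) == [1, 2, 1, 2].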
def create_distribution_strategy(distribution_strategy,
tpu_address,
input_partition_dims=None,
num_gpus=None):
"""Creates distribution strategy to use for computation."""
if input_partition_dims is not None:
if distribution_strategy != 'tpu':
raise ValueError('Spatial partitioning is only supported '
'for TPUStrategy.')
# When `input_partition_dims` is specified create custom TPUStrategy
# instance with computation shape for model parallelism.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
if tpu_address not in ('', 'local'):
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
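# E.g. an 8-core TPU with input_partition_dims = [1, 2, 1, 2] (four logical
# devices per replica) yields num_replicas = 8 // 4 = 2.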
num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
input_partition_dims)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
num_replicas=num_replicas,
computation_shape=input_partition_dims)
return tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment)
return distribute_utils.get_distribution_strategy(
distribution_strategy=distribution_strategy,
tpu_address=tpu_address,
num_gpus=num_gpus)
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
input_partition_dims = None
if FLAGS.mode == 'train_and_eval':
if np.prod(params.task.train_input_partition_dims) != np.prod(
params.task.eval_input_partition_dims):
raise ValueError('Train and eval input partition dims must specify the '
'same total number of logical devices.')
else:
input_partition_dims = get_computation_shape_for_model_parallelism(
params.task.train_input_partition_dims)
elif FLAGS.mode == 'train':
if params.task.train_input_partition_dims:
input_partition_dims = get_computation_shape_for_model_parallelism(
params.task.train_input_partition_dims)
elif FLAGS.mode == 'eval' or FLAGS.mode == 'continuous_eval':
if params.task.eval_input_partition_dims:
input_partition_dims = get_computation_shape_for_model_parallelism(
params.task.eval_input_partition_dims)
distribution_strategy = create_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
num_gpus=params.runtime.num_gpus,
input_partition_dims=input_partition_dims,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)