"driver/device_direct_convolution_2.hpp" did not exist on "28354a0fa374f71ceeb72ddccf09796701981b3c"
Commit a04d9e0e authored by Vishnu Banna

merged

parents 64f16d61 bcbce005
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Sets up TensorFlow Official Models."""
 import datetime
 import os
...
This directory contains projects using TensorFlow Model Garden Modeling
libraries.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,4 +11,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Some gradient util functions to help users writing custom training loop."""
 from absl import logging
...
@@ -69,7 +69,15 @@ class PseudoLabelDataConfig(cfg.DataConfig):
   """Psuedo Label input config for training."""
   input_path: str = ''
   data_ratio: float = 1.0  # Per-batch ratio of pseudo-labeled to labeled data.
+  is_training: bool = True
+  dtype: str = 'float32'
+  shuffle_buffer_size: int = 10000
+  cycle_length: int = 10
   aug_rand_hflip: bool = True
   aug_type: Optional[
       Augmentation] = None  # Choose from AutoAugment and RandAugment.
   file_type: str = 'tfrecord'
+  # Keep for backward compatibility.
+  aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
+  randaug_magnitude: Optional[int] = 10
# 3D ResNet-50g video classification on Kinetics-600.
#
# --experiment_type=video_classification_kinetics600
# Expected accuracy: 78.7% top-1, 93.6% top-5.
# Train on TPU: v3-128, eval on TPU: v3-32
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  init_checkpoint: null
  init_checkpoint_modules: all
  losses:
    l2_weight_decay: 0.0001
    label_smoothing: 0.0
  model:
    aggregate_endpoints: false
    backbone:
      resnet_3d:
        block_specs: !!python/tuple
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: true
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 1
          - 3
          - 1
          temporal_strides: 1
          use_self_gating: true
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 1
          - 3
          - 1
          - 3
          - 1
          temporal_strides: 1
          use_self_gating: true
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 3
          - 1
          temporal_strides: 1
          use_self_gating: true
        model_id: 50
        stem_conv_temporal_kernel_size: 5
        stem_conv_temporal_stride: 2
        stem_pool_temporal_stride: 2
        stem_type: v0
        stochastic_depth_drop_rate: 0.0
      type: resnet_3d
    dropout_rate: 0.2
    model_type: video_classification
    norm_activation:
      activation: relu
      norm_epsilon: 1.0e-05
      norm_momentum: 0.9
      use_sync_bn: false
  train_data:
    aug_max_area_ratio: 1.0
    aug_max_aspect_ratio: 2.0
    aug_min_area_ratio: 0.49
    aug_min_aspect_ratio: 0.5
    drop_remainder: true
    dtype: 'bfloat16'
    feature_shape: !!python/tuple
    - 64
    - 224
    - 224
    - 3
    global_batch_size: 1024
    min_image_size: 256
    name: kinetics600
    num_classes: 600
    split: train
  validation_data:
    dtype: 'bfloat16'
    feature_shape: !!python/tuple
    - 250
    - 224
    - 224
    - 3
    global_batch_size: 64
    min_image_size: 256
    name: kinetics600
    num_classes: 600
    num_examples: 27780
    num_test_clips: 1
    num_test_crops: 1
    one_hot: true
trainer:
  optimizer_config:
    learning_rate:
      cosine:
        alpha: 0.0
        decay_steps: 71400
        initial_learning_rate: 1.6
        name: CosineDecay
      type: cosine
    warmup:
      linear:
        name: linear
        warmup_learning_rate: 0
        warmup_steps: 1785
      type: linear
  train_steps: 71400
  steps_per_loop: 500
  summary_interval: 500
  validation_interval: 500
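For reference, the experiment name in the header comment is registered with the Model Garden experiment factory, so this YAML acts as an override on registered defaults. A minimal sketch of loading it programmatically (assuming the `official` package is on the path; the override dict below just echoes values from the YAML above):

import tensorflow as tf
from official.core import exp_factory

config = exp_factory.get_exp_config('video_classification_kinetics600')
config.override({
    'runtime': {'distribution_strategy': 'tpu',
                'mixed_precision_dtype': 'bfloat16'},
    'task': {'train_data': {'global_batch_size': 1024}},
})
print(config.task.train_data.global_batch_size)  # 1024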
@@ -43,6 +43,7 @@ class DataConfig(cfg.DataConfig):
   file_type: str = 'tfrecord'
   image_field_key: str = 'image/encoded'
   label_field_key: str = 'image/class/label'
+  decode_jpeg_only: bool = True
   # Keep for backward compatibility.
   aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
...
@@ -46,7 +46,7 @@ from official.vision.beta.data import tfrecord_lib
 flags.DEFINE_boolean(
     'include_masks', False, 'Whether to include instance segmentations masks '
     '(PNG encoded) in the result. default: False.')
-flags.DEFINE_string('image_dir', '', 'Directory containing images.')
+flags.DEFINE_multi_string('image_dir', '', 'Directory containing images.')
 flags.DEFINE_string(
     'image_info_file', '', 'File containing image information. '
     'Tf Examples in the output files correspond to the image '
@@ -159,7 +159,7 @@ def encode_caption_annotations(caption_annotations):
 def create_tf_example(image,
-                      image_dir,
+                      image_dirs,
                       bbox_annotations=None,
                       id_to_name_map=None,
                       caption_annotations=None,
@@ -169,7 +169,7 @@ def create_tf_example(image,
   Args:
     image: dict with keys: [u'license', u'file_name', u'coco_url', u'height',
       u'width', u'date_captured', u'flickr_url', u'id']
-    image_dir: directory containing the image files.
+    image_dirs: list of directories containing the image files.
     bbox_annotations:
       list of dicts with keys: [u'segmentation', u'area', u'iscrowd',
       u'image_id', u'bbox', u'category_id', u'id'] Notice that bounding box
@@ -190,14 +190,31 @@ def create_tf_example(image,
     num_annotations_skipped: Number of (invalid) annotations that were ignored.

   Raises:
-    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
+    ValueError: if the image pointed to by data['filename'] is not a valid JPEG,
+      does not exist, or is not unique across image directories.
   """
   image_height = image['height']
   image_width = image['width']
   filename = image['file_name']
   image_id = image['id']

-  full_path = os.path.join(image_dir, filename)
+  if len(image_dirs) > 1:
+    full_paths = [os.path.join(image_dir, filename) for image_dir in image_dirs]
+    full_existing_paths = [p for p in full_paths if tf.io.gfile.exists(p)]
+    if not full_existing_paths:
+      raise ValueError(
+          '{} does not exist across image directories.'.format(filename))
+    if len(full_existing_paths) > 1:
+      raise ValueError(
+          '{} is not unique across image directories'.format(filename))
+    full_path, = full_existing_paths
+  # If there is only one image directory, it's not worth checking for
+  # existence, since trying to open the file will raise an informative error
+  # message if it does not exist.
+  else:
+    image_dir, = image_dirs
+    full_path = os.path.join(image_dir, filename)
   with tf.io.gfile.GFile(full_path, 'rb') as fid:
     encoded_jpg = fid.read()
@@ -276,7 +293,7 @@ def _load_images_info(images_info_file):
   return info_dict['images']

-def generate_annotations(images, image_dir,
+def generate_annotations(images, image_dirs,
                          img_to_obj_annotation=None,
                          img_to_caption_annotation=None, id_to_name_map=None,
                          include_masks=False):
@@ -289,12 +306,12 @@ def generate_annotations(images, image_dir,
     caption_annotaion = (img_to_caption_annotation.get(image['id'], None) if
                          img_to_caption_annotation else None)
-    yield (image, image_dir, object_annotation, id_to_name_map,
+    yield (image, image_dirs, object_annotation, id_to_name_map,
            caption_annotaion, include_masks)

 def _create_tf_record_from_coco_annotations(images_info_file,
-                                            image_dir,
+                                            image_dirs,
                                             output_path,
                                             num_shards,
                                             object_annotations_file=None,
@@ -309,7 +326,7 @@ def _create_tf_record_from_coco_annotations(images_info_file,
       files Eg. 'image_info_test-dev2017.json',
       'instance_annotations_train2017.json',
       'caption_annotations_train2017.json', etc.
-    image_dir: Directory containing the image files.
+    image_dirs: List of directories containing the image files.
     output_path: Path to output tf.Record file.
     num_shards: Number of output files to create.
     object_annotations_file: JSON file containing bounding box annotations.
@@ -333,7 +350,7 @@ def _create_tf_record_from_coco_annotations(images_info_file,
       _load_caption_annotations(caption_annotations_file))

   coco_annotations_iter = generate_annotations(
-      images, image_dir, img_to_obj_annotation, img_to_caption_annotation,
+      images, image_dirs, img_to_obj_annotation, img_to_caption_annotation,
       id_to_name_map=id_to_name_map, include_masks=include_masks)

   num_skipped = tfrecord_lib.write_tf_record_dataset(
...
#!/bin/bash
#
# Processes the COCO few-shot benchmark into TFRecord files. Requires `wget`.

tmp_dir=$(mktemp -d -t coco-XXXXXXXXXX)
output_dir="/tmp/coco_few_shot"
while getopts "o:" o; do
  case "${o}" in
    o) output_dir=${OPTARG} ;;
    *) echo "Usage: ${0} [-o <output_dir>]" 1>&2; exit 1 ;;
  esac
done

cocosplit_url="dl.yf.io/fs-det/datasets/cocosplit"
wget --recursive --no-parent -q --show-progress --progress=bar:force:noscroll \
  -P "${tmp_dir}" -A "5k.json,*10shot*.json,*30shot*.json" \
  "http://${cocosplit_url}/"
mv "${tmp_dir}/${cocosplit_url}/"* "${tmp_dir}"
rm -rf "${tmp_dir}/${cocosplit_url}/"

python process_coco_few_shot_json_files.py \
  --logtostderr --workdir="${tmp_dir}"

for seed in {0..9}; do
  for shots in 10 30; do
    python create_coco_tf_record.py \
      --logtostderr \
      --image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
      --image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
      --image_info_file="${tmp_dir}/${shots}shot_seed${seed}.json" \
      --object_annotations_file="${tmp_dir}/${shots}shot_seed${seed}.json" \
      --caption_annotations_file="" \
      --output_file_prefix="${output_dir}/${shots}shot_seed${seed}" \
      --num_shards=4
  done
done

python create_coco_tf_record.py \
  --logtostderr \
  --image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
  --image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
  --image_info_file="${tmp_dir}/datasplit/5k.json" \
  --object_annotations_file="${tmp_dir}/datasplit/5k.json" \
  --caption_annotations_file="" \
  --output_file_prefix="${output_dir}/5k" \
  --num_shards=10

rm -rf "${tmp_dir}"
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processes the JSON files for COCO few-shot.
We assume that `workdir` mirrors the contents of
http://dl.yf.io/fs-det/datasets/cocosplit/, which contains the official JSON
files for the few-shot COCO evaluation procedure that Wang et al. (2020)'s
"Frustratingly Simple Few-Shot Object Detection" paper uses.
"""
import collections
import itertools
import json
import logging
import os
from absl import app
from absl import flags
import tensorflow as tf
logger = tf.get_logger()
logger.setLevel(logging.INFO)
flags.DEFINE_string('workdir', None, 'Working directory.')
FLAGS = flags.FLAGS
CATEGORIES = ['airplane', 'apple', 'backpack', 'banana', 'baseball bat',
'baseball glove', 'bear', 'bed', 'bench', 'bicycle', 'bird',
'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake',
'car', 'carrot', 'cat', 'cell phone', 'chair', 'clock', 'couch',
'cow', 'cup', 'dining table', 'dog', 'donut', 'elephant',
'fire hydrant', 'fork', 'frisbee', 'giraffe', 'hair drier',
'handbag', 'horse', 'hot dog', 'keyboard', 'kite', 'knife',
'laptop', 'microwave', 'motorcycle', 'mouse', 'orange', 'oven',
'parking meter', 'person', 'pizza', 'potted plant',
'refrigerator', 'remote', 'sandwich', 'scissors', 'sheep',
'sink', 'skateboard', 'skis', 'snowboard', 'spoon', 'sports ball',
'stop sign', 'suitcase', 'surfboard', 'teddy bear',
'tennis racket', 'tie', 'toaster', 'toilet', 'toothbrush',
'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase',
'wine glass', 'zebra']
SEEDS = list(range(10))
SHOTS = [10, 30]
FILE_SUFFIXES = collections.defaultdict(list)
for _seed, _shots in itertools.product(SEEDS, SHOTS):
  for _category in CATEGORIES:
    FILE_SUFFIXES[(_seed, _shots)].append(
        '{}full_box_{}shot_{}_trainval.json'.format(
            # http://dl.yf.io/fs-det/datasets/cocosplit/ is organized like so:
            #
            #   datasplit/
            #     trainvalno5k.json
            #     5k.json
            #   full_box_{1,2,3,5,10,30}shot_{category}_trainval.json
            #   seed{1-9}/
            #     full_box_{1,2,3,5,10,30}shot_{category}_trainval.json
            #
            # This means that the JSON files for seed0 are located in the root
            # directory rather than in a `seed?/` subdirectory, hence the
            # conditional expression below.
            '' if _seed == 0 else 'seed{}/'.format(_seed),
            _shots,
            _category))
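(For reference, a worked example of the pattern above: `FILE_SUFFIXES[(3, 10)]` begins with 'seed3/full_box_10shot_airplane_trainval.json', while the seed-0 entries sit at the top level, e.g. 'full_box_10shot_airplane_trainval.json'.)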
def main(unused_argv):
  workdir = FLAGS.workdir

  for seed, shots in itertools.product(SEEDS, SHOTS):
    # Retrieve all examples for a given seed and shots setting.
    file_paths = [os.path.join(workdir, suffix)
                  for suffix in FILE_SUFFIXES[(seed, shots)]]
    json_dicts = []
    for file_path in file_paths:
      with tf.io.gfile.GFile(file_path, 'r') as f:
        json_dicts.append(json.load(f))

    # Make sure that all JSON files for a given seed and shots setting have the
    # same metadata. We count on this to fuse them later on.
    metadata_dicts = [{'info': d['info'], 'licenses': d['licenses'],
                       'categories': d['categories']} for d in json_dicts]
    if not all(d == metadata_dicts[0] for d in metadata_dicts[1:]):
      raise RuntimeError(
          'JSON files for {} shots (seed {}) '.format(shots, seed) +
          'have different info, licences, or categories fields')

    # Retrieve images across all JSON files.
    images = sum((d['images'] for d in json_dicts), [])
    # Remove duplicate image entries.
    images = list({image['id']: image for image in images}.values())

    output_dict = {
        'info': json_dicts[0]['info'],
        'licenses': json_dicts[0]['licenses'],
        'categories': json_dicts[0]['categories'],
        'images': images,
        'annotations': sum((d['annotations'] for d in json_dicts), [])
    }
    output_path = os.path.join(workdir,
                               '{}shot_seed{}.json'.format(shots, seed))
    with tf.io.gfile.GFile(output_path, 'w') as f:
      json.dump(output_dict, f)
    logger.info('Processed %d shots (seed %d) and saved to %s',
                shots, seed, output_path)

if __name__ == '__main__':
  flags.mark_flag_as_required('workdir')
  app.run(main)
@@ -66,6 +66,7 @@ class Parser(parser.Parser):
                num_classes: float,
                image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
                label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
+               decode_jpeg_only: bool = True,
                aug_rand_hflip: bool = True,
                aug_type: Optional[common.Augmentation] = None,
                is_multilabel: bool = False,
@@ -78,6 +79,8 @@ class Parser(parser.Parser):
       num_classes: `float`, number of classes.
       image_field_key: `str`, the key name to encoded image in tf.Example.
       label_field_key: `str`, the key name to label in tf.Example.
+      decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
+        faster than decoding other types. Default is True.
       aug_rand_hflip: `bool`, if True, augment training with random
         horizontal flip.
       aug_type: An optional Augmentation object to choose from AutoAugment and
@@ -118,6 +121,7 @@ class Parser(parser.Parser):
     self._augmenter = None
     self._label_field_key = label_field_key
     self._is_multilabel = is_multilabel
+    self._decode_jpeg_only = decode_jpeg_only

   def _parse_train_data(self, decoded_tensors):
     """Parses data for training."""
@@ -142,16 +146,29 @@ class Parser(parser.Parser):
   def _parse_train_image(self, decoded_tensors):
     """Parses image data for training."""
     image_bytes = decoded_tensors[self._image_field_key]
-    image_shape = tf.image.extract_jpeg_shape(image_bytes)

-    # Crops image.
-    # TODO(pengchong): support image format other than JPEG.
-    cropped_image = preprocess_ops.random_crop_image_v2(
-        image_bytes, image_shape)
-    image = tf.cond(
-        tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
-        lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
-        lambda: cropped_image)
+    if self._decode_jpeg_only:
+      image_shape = tf.image.extract_jpeg_shape(image_bytes)
+
+      # Crops image.
+      cropped_image = preprocess_ops.random_crop_image_v2(
+          image_bytes, image_shape)
+      image = tf.cond(
+          tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
+          lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
+          lambda: cropped_image)
+    else:
+      # Decodes image.
+      image = tf.io.decode_image(image_bytes, channels=3)
+      image.set_shape([None, None, 3])
+
+      # Crops image.
+      cropped_image = preprocess_ops.random_crop_image(image)
+      image = tf.cond(
+          tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
+          lambda: preprocess_ops.center_crop_image(image),
+          lambda: cropped_image)

     if self._aug_rand_hflip:
       image = tf.image.random_flip_left_right(image)
@@ -159,6 +176,7 @@ class Parser(parser.Parser):
     # Resizes image.
     image = tf.image.resize(
         image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
+    image.set_shape([self._output_size[0], self._output_size[1], 3])

     # Apply autoaug or randaug.
     if self._augmenter is not None:
@@ -177,15 +195,23 @@ class Parser(parser.Parser):
   def _parse_eval_image(self, decoded_tensors):
     """Parses image data for evaluation."""
     image_bytes = decoded_tensors[self._image_field_key]
-    image_shape = tf.image.extract_jpeg_shape(image_bytes)

-    # Center crops and resizes image.
-    image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
+    if self._decode_jpeg_only:
+      image_shape = tf.image.extract_jpeg_shape(image_bytes)
+
+      # Center crops.
+      image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
+    else:
+      # Decodes image.
+      image = tf.io.decode_image(image_bytes, channels=3)
+      image.set_shape([None, None, 3])
+
+      # Center crops.
+      image = preprocess_ops.center_crop_image(image)

     image = tf.image.resize(
         image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
-    image = tf.reshape(image, [self._output_size[0], self._output_size[1], 3])
+    image.set_shape([self._output_size[0], self._output_size[1], 3])

     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(image,
...
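The `decode_jpeg_only` fast path above relies on `tf.image.extract_jpeg_shape`, which reads an image's dimensions from the JPEG header without decoding any pixels, combined with decoding only a crop window from the encoded bytes. A minimal self-contained sketch of the same idea (the helper name is ours, not from this file):

import tensorflow as tf

def fast_center_crop_jpeg(image_bytes: tf.Tensor, size: int = 224) -> tf.Tensor:
  """Center-crops from encoded JPEG bytes without a full decode."""
  # Reads [height, width, channels] from the JPEG header only.
  shape = tf.image.extract_jpeg_shape(image_bytes)
  offset_h = (shape[0] - size) // 2
  offset_w = (shape[1] - size) // 2
  crop_window = tf.stack([offset_h, offset_w, size, size])
  # Decodes only the requested window instead of the whole image.
  return tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)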
@@ -127,10 +127,12 @@ def _encode_image(image_array: np.ndarray, fmt: str) -> bytes:
 def create_classification_example(
     image_height: int,
     image_width: int,
+    image_format: str = 'JPEG',
     is_multilabel: bool = False) -> tf.train.Example:
   """Creates image and labels for image classification input pipeline."""
   image = _encode_image(
-      np.uint8(np.random.rand(image_height, image_width, 3) * 255), fmt='JPEG')
+      np.uint8(np.random.rand(image_height, image_width, 3) * 255),
+      fmt=image_format)
   labels = [0, 1] if is_multilabel else [0]
   serialized_example = tf.train.Example(
       features=tf.train.Features(
...
@@ -502,7 +502,7 @@ class MobileNet(tf.keras.Model):
       kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
       # The followings should be kept the same most of the times.
-      output_stride: int = None,
+      output_stride: Optional[int] = None,
       min_depth: int = 8,
       # divisible is not used in MobileNetV1.
       divisible_by: int = 8,
@@ -768,7 +768,8 @@ def build_mobilenet(
     input_specs: tf.keras.layers.InputSpec,
     backbone_config: hyperparams.Config,
     norm_activation_config: hyperparams.Config,
-    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
+    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
+) -> tf.keras.Model:
   """Builds MobileNet backbone from a config."""
   backbone_type = backbone_config.type
   backbone_cfg = backbone_config.get()
...
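This annotation change, and the similar ones in the sections that follow, bring the signatures in line with PEP 484, which requires an explicit `Optional[...]` for any parameter that defaults to `None`. A minimal illustration with a hypothetical function (not from the repo):

from typing import Optional

# PEP 484: a parameter defaulting to None must be annotated Optional[...];
# strict type checkers reject the implicit form.
def scale(value: float, factor: Optional[float] = None) -> float:
  return value * (1.0 if factor is None else factor)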
@@ -81,7 +81,7 @@ class ResNet3D(tf.keras.Model):
                model_id: int,
                temporal_strides: List[int],
                temporal_kernel_sizes: List[Tuple[int]],
-               use_self_gating: List[int] = None,
+               use_self_gating: Optional[List[int]] = None,
                input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
                    shape=[None, None, None, None, 3]),
                stem_type: str = 'v0',
@@ -380,7 +380,8 @@ def build_resnet3d(
     input_specs: tf.keras.layers.InputSpec,
     backbone_config: hyperparams.Config,
     norm_activation_config: hyperparams.Config,
-    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
+    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
+) -> tf.keras.Model:
   """Builds ResNet 3d backbone from a config."""
   backbone_cfg = backbone_config.get()
@@ -418,7 +419,8 @@ def build_resnet3d_rs(
     input_specs: tf.keras.layers.InputSpec,
     backbone_config: hyperparams.Config,
     norm_activation_config: hyperparams.Config,
-    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
+    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
+) -> tf.keras.Model:
   """Builds ResNet-3D-RS backbone from a config."""
   backbone_cfg = backbone_config.get()
...
@@ -36,7 +36,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
                num_anchors_per_location: int,
                num_convs: int = 4,
                num_filters: int = 256,
-               attribute_heads: List[Dict[str, Any]] = None,
+               attribute_heads: Optional[List[Dict[str, Any]]] = None,
                use_separable_conv: bool = False,
                activation: str = 'relu',
                use_sync_bn: bool = False,
...
@@ -593,7 +593,7 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
                raw_scores: Mapping[str, tf.Tensor],
                anchor_boxes: tf.Tensor,
                image_shape: tf.Tensor,
-               raw_attributes: Mapping[str, tf.Tensor] = None):
+               raw_attributes: Optional[Mapping[str, tf.Tensor]] = None):
     """Generates final detections.

     Args:
...
@@ -132,8 +132,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
   def build(self, input_shape):
     num_reduced_filters = make_divisible(
-        max(1, int(self._in_filters * self._se_ratio)),
-        divisor=self._divisible_by)
+        self._in_filters * self._se_ratio, divisor=self._divisible_by)

     self._se_reduce = tf.keras.layers.Conv2D(
         filters=num_reduced_filters,
@@ -282,9 +281,6 @@ class Scale(tf.keras.layers.Layer):
   This is useful for applying ReZero to layers, which improves convergence
   speed. This implements the paper:
-  Thomas Bachlechner, Bodhisattwa Prasad Majumder, Huanru Henry Mao,
-  Garrison W. Cottrell, Julian McAuley.
   ReZero is All You Need: Fast Convergence at Large Depth.
   (https://arxiv.org/pdf/2003.04887.pdf).
   """
@@ -372,6 +368,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
   def __init__(self,
                initializer: tf.keras.initializers.Initializer = 'zeros',
                cache_encoding: bool = False,
+               state_prefix: Optional[str] = None,
                **kwargs):
     """Initializes positional encoding.
@@ -381,6 +378,7 @@ class PositionalEncoding(tf.keras.layers.Layer):
         after calling build. Otherwise, rebuild the tensor for every call.
         Setting this to False can be useful when we want to input a variable
         number of frames, so the positional encoding tensor can change shape.
+      state_prefix: a prefix string to identify states.
       **kwargs: Additional keyword arguments to be passed to this layer.

     Returns:
@@ -391,33 +389,43 @@ class PositionalEncoding(tf.keras.layers.Layer):
     self._cache_encoding = cache_encoding
     self._pos_encoding = None
     self._rezero = Scale(initializer=initializer, name='rezero')
+    state_prefix = state_prefix if state_prefix is not None else ''
+    self._state_prefix = state_prefix
+    self._frame_count_name = f'{state_prefix}/pos_enc_frame_count'

   def get_config(self):
     """Returns a dictionary containing the config used for initialization."""
     config = {
         'initializer': self._initializer,
         'cache_encoding': self._cache_encoding,
+        'state_prefix': self._state_prefix,
     }
     base_config = super(PositionalEncoding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

   def _positional_encoding(self,
-                           num_positions: int,
-                           hidden_size: int,
-                           dtype: tf.DType = tf.float32):
+                           num_positions: Union[int, tf.Tensor],
+                           hidden_size: Union[int, tf.Tensor],
+                           start_position: Union[int, tf.Tensor] = 0,
+                           dtype: str = 'float32') -> tf.Tensor:
     """Creates a sequence of sinusoidal positional encoding vectors.

     Args:
-      num_positions: An `int` of number of positions (frames).
-      hidden_size: An `int` of number of channels used for the hidden vectors.
-      dtype: The dtype of the output tensor.
+      num_positions: the total number of positions (frames).
+      hidden_size: the number of channels used for the hidden vectors.
+      start_position: the start position.
+      dtype: the dtype of the output tensor.

     Returns:
       The positional encoding tensor with shape [num_positions, hidden_size].
     """
+    if isinstance(start_position, tf.Tensor) and start_position.shape.rank == 1:
+      start_position = start_position[0]
+
     # Calling `tf.range` with `dtype=tf.bfloat16` results in an error,
     # so we cast afterward.
-    positions = tf.cast(tf.range(num_positions)[:, tf.newaxis], dtype)
+    positions = tf.range(start_position, start_position + num_positions)
+    positions = tf.cast(positions, dtype)[:, tf.newaxis]
     idx = tf.range(hidden_size)[tf.newaxis, :]
     power = tf.cast(2 * (idx // 2), dtype)
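Aside: the vectors built in `_positional_encoding` follow the standard sinusoidal scheme of Vaswani et al. (2017), now shifted by the new `start_position` offset (written $p_0$ here, our notation):

$$\mathrm{PE}(p,\,2i) = \sin\!\left(\frac{p_0 + p}{10000^{2i/d}}\right), \qquad \mathrm{PE}(p,\,2i+1) = \cos\!\left(\frac{p_0 + p}{10000^{2i/d}}\right)$$

where $d$ is `hidden_size` and the `power` term above supplies the exponent $2i$ shared by each sine/cosine channel pair.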
@@ -431,11 +439,24 @@ class PositionalEncoding(tf.keras.layers.Layer):
     return pos_encoding

-  def _get_pos_encoding(self, input_shape):
-    """Calculates the positional encoding from the input shape."""
+  def _get_pos_encoding(self,
+                        input_shape: tf.Tensor,
+                        frame_count: int = 0) -> tf.Tensor:
+    """Calculates the positional encoding from the input shape.
+
+    Args:
+      input_shape: the shape of the input.
+      frame_count: a count of frames that indicates the index of the first
+        frame.
+
+    Returns:
+      The positional encoding tensor with shape [num_positions, hidden_size].
+    """
     frames = input_shape[1]
     channels = input_shape[-1]
-    pos_encoding = self._positional_encoding(frames, channels, dtype=self.dtype)
+    pos_encoding = self._positional_encoding(
+        frames, channels, start_position=frame_count, dtype=self.dtype)
     pos_encoding = tf.reshape(pos_encoding, [1, frames, 1, 1, channels])
     return pos_encoding
@@ -456,16 +477,46 @@ class PositionalEncoding(tf.keras.layers.Layer):
     super(PositionalEncoding, self).build(input_shape)

-  def call(self, inputs):
-    """Calls the layer with the given inputs."""
+  def call(
+      self,
+      inputs: tf.Tensor,
+      states: Optional[States] = None,
+      output_states: bool = True,
+  ) -> Union[tf.Tensor, Tuple[tf.Tensor, States]]:
+    """Calls the layer with the given inputs.
+
+    Args:
+      inputs: An input `tf.Tensor`.
+      states: A `dict` of states such that, if any of the keys match for this
+        layer, will overwrite the contents of the buffer(s). Expected keys
+        include `state_prefix + '/pos_enc_frame_count'`.
+      output_states: A `bool`. If True, returns the output tensor and output
+        states. Returns just the output tensor otherwise.
+
+    Returns:
+      An output `tf.Tensor` (and optionally the states if
+      `output_states=True`).
+
+    Raises:
+      ValueError: If using 'channels_first' data format.
+    """
+    states = dict(states) if states is not None else {}
+
+    # Keep a count of frames encountered across input iterations in
+    # num_frames to be able to accurately update the positional encoding.
+    num_frames = tf.shape(inputs)[1]
+    frame_count = tf.cast(states.get(self._frame_count_name, [0]), tf.int32)
+    states[self._frame_count_name] = frame_count + num_frames

     if self._cache_encoding:
       pos_encoding = self._pos_encoding
     else:
-      pos_encoding = self._get_pos_encoding(tf.shape(inputs))
+      pos_encoding = self._get_pos_encoding(
+          tf.shape(inputs), frame_count=frame_count)
     pos_encoding = tf.cast(pos_encoding, inputs.dtype)
+    pos_encoding = tf.stop_gradient(pos_encoding)
     pos_encoding = self._rezero(pos_encoding)
-    return inputs + pos_encoding
+    outputs = inputs + pos_encoding
+
+    return (outputs, states) if output_states else outputs
 @tf.keras.utils.register_keras_serializable(package='Vision')
@@ -481,6 +532,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
   def __init__(self,
                keepdims: bool = False,
                causal: bool = False,
+               state_prefix: Optional[str] = None,
                **kwargs):
     """Initializes a global average pool layer.
@@ -488,6 +540,7 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
       keepdims: A `bool`. If True, keep the averaged dimensions.
       causal: A `bool` of whether to run in causal mode with a cumulative sum
         across frames.
+      state_prefix: a prefix string to identify states.
       **kwargs: Additional keyword arguments to be passed to this layer.

     Returns:
@@ -497,29 +550,22 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
     self._keepdims = keepdims
     self._causal = causal
+    state_prefix = state_prefix if state_prefix is not None else ''
+    self._state_prefix = state_prefix
-    self._frame_count = None
+    self._state_name = f'{state_prefix}/pool_buffer'
+    self._frame_count_name = f'{state_prefix}/pool_frame_count'

   def get_config(self):
     """Returns a dictionary containing the config used for initialization."""
     config = {
         'keepdims': self._keepdims,
         'causal': self._causal,
+        'state_prefix': self._state_prefix,
     }
     base_config = super(GlobalAveragePool3D, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def build(self, input_shape):
-    """Builds the layer with the given input shape."""
-    # Here we define strings that will uniquely reference the buffer states
-    # in the TF graph. These will be used for passing in a mapping of states
-    # for streaming mode. To do this, we can use a name scope.
-    with tf.name_scope('buffer') as state_name:
-      self._state_name = state_name
-      self._frame_count_name = state_name + '_frame_count'
-    super(GlobalAveragePool3D, self).build(input_shape)

   def call(self,
            inputs: tf.Tensor,
            states: Optional[States] = None,
@@ -531,6 +577,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
       inputs: An input `tf.Tensor`.
       states: A `dict` of states such that, if any of the keys match for this
         layer, will overwrite the contents of the buffer(s).
+        Expected keys include `state_prefix + '/pool_buffer'` and
+        `state_prefix + '/pool_frame_count'`.
       output_states: A `bool`. If True, returns the output tensor and output
         states. Returns just the output tensor otherwise.
@@ -562,7 +610,8 @@ class GlobalAveragePool3D(tf.keras.layers.Layer):
     # num_frames to be able to accurately take a cumulative average across
     # all frames when running in streaming mode
     num_frames = tf.shape(inputs)[1]
-    frame_count = states.get(self._frame_count_name, 0)
+    frame_count = states.get(self._frame_count_name, tf.constant([0]))
+    frame_count = tf.cast(frame_count, tf.int32)
     states[self._frame_count_name] = frame_count + num_frames

     if self._causal:
...
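A minimal usage sketch of the streaming contract above (module path and shapes are assumptions; each call feeds the returned states back in, mirroring the positional-encoding test that follows):

import tensorflow as tf
from official.vision.beta.modeling.layers import nn_layers

# Causal mode keeps a running buffer and frame count in `states`.
pool = nn_layers.GlobalAveragePool3D(keepdims=True, causal=True,
                                     state_prefix='base')
clips = tf.split(tf.random.uniform([1, 8, 7, 7, 16]), 4, axis=1)
states = {}
for clip in clips:
  # The returned states carry the pool buffer and frame count forward.
  output, states = pool(clip, states=states, output_states=True)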
@@ -48,8 +48,8 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
         initializer='ones', cache_encoding=True)

     inputs = tf.ones([1, 4, 1, 1, 3])
-    outputs = pos_encoding(inputs)
-    outputs_cached = pos_encoding_cached(inputs)
+    outputs, _ = pos_encoding(inputs)
+    outputs_cached, _ = pos_encoding_cached(inputs)

     expected = tf.constant(
         [[[[[1.0000000, 1.0000000, 2.0000000]]],
@@ -70,7 +70,7 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
     pos_encoding = nn_layers.PositionalEncoding(initializer='ones')

     inputs = tf.ones([1, 4, 1, 1, 3], dtype=tf.bfloat16)
-    outputs = pos_encoding(inputs)
+    outputs, _ = pos_encoding(inputs)

     expected = tf.constant(
         [[[[[1.0000000, 1.0000000, 2.0000000]]],
@@ -92,6 +92,31 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
     self.assertEqual(outputs.shape, expected.shape)
     self.assertAllEqual(outputs, expected)

+  def test_positional_encoding_stream(self):
+    pos_encoding = nn_layers.PositionalEncoding(
+        initializer='ones', cache_encoding=False)
+
+    inputs = tf.range(4, dtype=tf.float32) + 1.
+    inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
+    inputs = tf.tile(inputs, [1, 1, 1, 1, 3])
+    expected, _ = pos_encoding(inputs)
+
+    for num_splits in [1, 2, 4]:
+      frames = tf.split(inputs, num_splits, axis=1)
+      states = {}
+      predicted = []
+      for frame in frames:
+        output, states = pos_encoding(frame, states=states)
+        predicted.append(output)
+      predicted = tf.concat(predicted, axis=1)
+
+      self.assertEqual(predicted.shape, expected.shape)
+      self.assertAllClose(predicted, expected)
+      self.assertAllClose(predicted, [[[[[1.0000000, 1.0000000, 2.0000000]]],
+                                       [[[2.8414710, 2.0021544, 2.5403023]]],
+                                       [[[3.9092975, 3.0043090, 2.5838532]]],
+                                       [[[4.1411200, 4.0064630, 3.0100074]]]]])
+
   def test_global_average_pool_keras(self):
     pool = nn_layers.GlobalAveragePool3D(keepdims=False)
     keras_pool = tf.keras.layers.GlobalAveragePooling3D()
...
@@ -140,10 +140,10 @@ class MaskRCNNModel(tf.keras.Model):
            images: tf.Tensor,
            image_shape: tf.Tensor,
            anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
-           gt_boxes: tf.Tensor = None,
-           gt_classes: tf.Tensor = None,
-           gt_masks: tf.Tensor = None,
-           training: bool = None) -> Mapping[str, tf.Tensor]:
+           gt_boxes: Optional[tf.Tensor] = None,
+           gt_classes: Optional[tf.Tensor] = None,
+           gt_masks: Optional[tf.Tensor] = None,
+           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
     model_outputs = {}

     # Feature extraction.
...