Unverified commit cdd61f61, authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents 0225b135 a9322830
@@ -1087,16 +1087,17 @@ class Encoder(Module):
   @tf.Module.with_name_scope
   def __call__(self,
-               inputs,
+               inputs=None,
                encoder_mask=None,
                dense_inputs=None,
                training=False):
     """Applies Transformer model on the inputs.

     Args:
-      inputs: input data
+      inputs: input word ids. Optional if dense data are provided.
       encoder_mask: the encoder self-attention mask.
-      dense_inputs: dense input data, concat after the embedding.
+      dense_inputs: dense input data. Concat after the embedding if word ids
+        are provided.
       training: whether it is training pass, affecting dropouts.

     Returns:
@@ -1106,16 +1107,27 @@ class Encoder(Module):
     if encoder_mask is not None:
       encoder_mask = tf.cast(encoder_mask, self.compute_dtype)
     cfg = self.config
-    x = self.input_embed(inputs, one_hot=cfg.one_hot_embedding)
+    inputs_array = []
+    if inputs is not None:
+      inputs_array.append(
+          self.input_embed(inputs, one_hot=cfg.one_hot_embedding))
     if dense_inputs is not None:
-      x = tf.concat([x, dense_inputs], axis=1)
+      inputs_array.append(dense_inputs)
+    if not inputs_array:
+      raise ValueError("At least one of inputs and dense_inputs must not be "
+                       "None.")
+    x = tf.concat(inputs_array, axis=1)
     tensor_shape = tf_utils.get_shape_list(x)
     tensor_shape[-2] = 1
     x = self.input_dropout(x, noise_shape=tensor_shape, training=training)
-    input_length = tf_utils.get_shape_list(inputs)[1]
+    if inputs is not None:
+      input_length = tf_utils.get_shape_list(inputs)[1]
+    else:
+      input_length = 0
     position_bias = self.relative_embedding(input_length, input_length)
     if dense_inputs is not None:
       # Here we ignore relative position bias for dense embeddings.
+      # TODO(yejiayu): If we proceed to video use cases, rework this part.
       dense_input_length = tf_utils.get_shape_list(dense_inputs)[1]
       # Position bias shape: [batch, 1, len, len]
       paddings = tf.constant([[0, 0], [0, 0], [0, dense_input_length],
@@ -1320,25 +1332,35 @@ class T5Transformer(Module):
         compute_dtype=self.compute_dtype)

   def encode(self,
-             encoder_input_tokens,
+             encoder_input_tokens=None,
              encoder_segment_ids=None,
              encoder_dense_inputs=None,
              encoder_dense_segment_ids=None,
              training=False):
-    eligible_positions = tf.cast(
-        tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
+    eligible_position_array = []
+    if encoder_input_tokens is not None:
+      eligible_position_array.append(
+          tf.cast(tf.not_equal(encoder_input_tokens, 0), self.compute_dtype))
     if encoder_dense_inputs is not None:
-      eligible_dense_position = tf.cast(
+      eligible_dense_positions = tf.cast(
           tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
           self.compute_dtype)
-      eligible_positions = tf.concat(
-          [eligible_positions, eligible_dense_position], axis=1)
+      eligible_position_array.append(eligible_dense_positions)
+    if not eligible_position_array:
+      raise ValueError("At least one of encoder_input_tokens and"
+                       " encoder_dense_inputs must be provided.")
+    eligible_positions = tf.concat(eligible_position_array, axis=1)
     encoder_mask = make_attention_mask(
         eligible_positions, eligible_positions, dtype=tf.bool)
+    encoder_segment_id_array = []
     if encoder_segment_ids is not None:
-      if encoder_dense_segment_ids is not None:
-        encoder_segment_ids = tf.concat(
-            [encoder_segment_ids, encoder_dense_segment_ids], axis=1)
+      encoder_segment_id_array.append(encoder_segment_ids)
+    if encoder_dense_segment_ids is not None:
+      encoder_segment_id_array.append(encoder_dense_segment_ids)
+    if encoder_segment_id_array:
+      encoder_segment_ids = tf.concat(encoder_segment_id_array, axis=1)
       segment_mask = make_attention_mask(
           encoder_segment_ids, encoder_segment_ids, tf.equal, dtype=tf.bool)
       encoder_mask = tf.math.logical_and(encoder_mask, segment_mask)
@@ -1353,7 +1375,7 @@ class T5Transformer(Module):
       self,
       encoded,
       decoder_target_tokens,
-      encoder_input_tokens,  # only used for masks
+      encoder_input_tokens=None,  # only used for masks
       encoder_dense_inputs=None,
       decoder_input_tokens=None,
       encoder_segment_ids=None,
@@ -1364,14 +1386,18 @@ class T5Transformer(Module):
       max_decode_len=None,
       decode=False,
       training=False):
-    eligible_inputs = tf.cast(
-        tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
+    eligible_inputs_array = []
+    if encoder_input_tokens is not None:
+      eligible_inputs = tf.cast(
+          tf.not_equal(encoder_input_tokens, 0), self.compute_dtype)
+      eligible_inputs_array.append(eligible_inputs)
     if encoder_dense_inputs is not None:
       eligible_dense_inputs = tf.cast(
           tf.reduce_any(tf.not_equal(encoder_dense_inputs, 0), axis=-1),
           self.compute_dtype)
-      eligible_inputs = tf.concat([eligible_inputs, eligible_dense_inputs],
-                                  axis=1)
+      eligible_inputs_array.append(eligible_dense_inputs)
+    eligible_inputs = tf.concat(eligible_inputs_array, axis=1)
     if decode:
       # For decoding, the decoder_input_tokens is the decoder_target_tokens.
       decoder_input_tokens = decoder_target_tokens
@@ -1430,8 +1456,8 @@ class T5Transformer(Module):
   @tf.Module.with_name_scope
   def __call__(self,
-               encoder_input_tokens,
-               decoder_target_tokens,
+               encoder_input_tokens=None,
+               decoder_target_tokens=None,
                encoder_dense_inputs=None,
                encoder_dense_segment_ids=None,
                decoder_input_tokens=None,
@@ -1456,7 +1482,7 @@ class T5Transformer(Module):
       a dictionary of logits/cache.
     """
     encoded = self.encode(
-        encoder_input_tokens,
+        encoder_input_tokens=encoder_input_tokens,
         encoder_segment_ids=encoder_segment_ids,
         encoder_dense_inputs=encoder_dense_inputs,
         encoder_dense_segment_ids=encoder_dense_segment_ids,
......
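Taken together, the hunks above make word ids and dense features independently optional across `Encoder.__call__`, `T5Transformer.encode`, `decode`, and `__call__`; a `ValueError` is raised only when neither input is supplied. A minimal usage sketch of the new behaviour is below (the `official.nlp.modeling.models.t5` import path is assumed; config values and shapes mirror the unit tests in the next file):

```python
# Illustrative sketch of the new optional-input paths; not part of the commit.
import tensorflow as tf
from official.nlp.modeling.models import t5  # assumed import path

config = t5.T5TransformerParams(
    num_layers=2,
    d_model=4,
    d_kv=3,
    num_heads=4,
    d_ff=16,
    vocab_size=10,
    vocab_embeddings_initializer=tf.keras.initializers.Ones(),
    relative_embeddings_initializer=tf.keras.initializers.Ones())
encoder = t5.Encoder(config, compute_dtype=tf.float32)

# Dense-only: word ids may now be omitted entirely.
dense_only = encoder(dense_inputs=tf.ones((4, 2, 4)))        # -> (4, 2, d_model)

# Mixed: dense features are concatenated after the token embeddings.
word_ids = tf.ones((4, 8), dtype=tf.int32)
mixed = encoder(word_ids, dense_inputs=tf.ones((4, 2, 4)))   # -> (4, 10, d_model)

# Calling the encoder with neither inputs nor dense_inputs raises a ValueError.
```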
@@ -372,6 +372,22 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
         dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
     self.assertEqual(encoded.shape, (4, 10, config.d_model))

+  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
+                                  ("float32", tf.float32))
+  def test_encoder_only_dense(self, dtype):
+    config = t5.T5TransformerParams(
+        num_layers=2,
+        d_model=4,
+        d_kv=3,
+        num_heads=4,
+        d_ff=16,
+        vocab_size=10,
+        vocab_embeddings_initializer=tf.keras.initializers.Ones(),
+        relative_embeddings_initializer=tf.keras.initializers.Ones())
+    encoder = t5.Encoder(config, compute_dtype=dtype)
+    encoded = encoder(dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
+    self.assertEqual(encoded.shape, (4, 2, config.d_model))
+
   def test_decoder(self):
     max_decode_len = 10
     config = t5.T5TransformerParams(
@@ -515,6 +531,58 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
       print(v.name, v.shape)
       self.assertEqual(v.dtype, tf.float32)

+  @parameterized.named_parameters(
+      ("t5_10", ("relu",), True, 26, False, tf.float32),)
+  def test_transformer_with_dense_only(self, ffn_activations,
+                                       logits_via_embedding,
+                                       expect_num_variables, layer_sharing,
+                                       dtype):
+    max_decode_len = 10
+    config = t5.T5TransformerParams(
+        num_layers=1,
+        d_model=8,
+        d_kv=4,
+        num_heads=4,
+        d_ff=32,
+        vocab_size=10,
+        shared_embedding=True,
+        layer_sharing=layer_sharing,
+        ffn_activations=ffn_activations,
+        logits_via_embedding=logits_via_embedding)
+    transformer = t5.T5Transformer(config, compute_dtype=dtype)
+    self.assertLen(transformer.trainable_variables, expect_num_variables)
+    decoder_inputs = tf.convert_to_tensor(
+        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
+    decoder_segments = tf.convert_to_tensor(
+        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
+    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
+    dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
+    outputs = transformer(
+        encoder_dense_inputs=dense_inputs,
+        encoder_dense_segment_ids=dense_segments,
+        decoder_input_tokens=decoder_inputs,
+        decoder_target_tokens=decoder_inputs,
+        decoder_segment_ids=decoder_segments)
+    cache = {}
+    batch_size = 2
+    cache[0] = _create_cache(
+        batch_size, max_decode_len, config.num_heads, config.d_kv, dtype=dtype)
+    outputs = transformer.decode(
+        encoder_dense_inputs=dense_inputs,
+        encoded=outputs["encoded"],
+        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
+        decode_position=1,
+        decode=True,
+        max_decode_len=max_decode_len,
+        cache=cache)
+    self.assertEqual(outputs["logits"].shape,
+                     (batch_size, 1, config.vocab_size))
+    for v in transformer.trainable_variables:
+      print(v.name, v.shape)
+      self.assertEqual(v.dtype, tf.float32)
+
   @parameterized.named_parameters(
       ("t5_10", ("relu",), True, 39, tf.float32, 2),
       ("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
......
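The dense-only transformer test above exercises incremental decoding through a per-layer cache built by the test-local `_create_cache` helper, which is not part of this diff. A hedged sketch of such a helper is shown below; the zero-initialized `[batch, max_decode_len, num_heads, head_size]` layout is an assumption and would have to match the attention cache layout `t5.py` actually expects.

```python
# Hypothetical sketch of the test-local _create_cache helper (not in this diff).
import tensorflow as tf


def _create_cache(batch_size, init_decode_length, num_heads, head_size,
                  dtype=tf.float32):
  # Assumed layout: one zero-filled key/value pair per decoder layer, sized for
  # the full decode length so positions can be written in place during decoding.
  return {
      "key": tf.zeros(
          [batch_size, init_decode_length, num_heads, head_size], dtype=dtype),
      "value": tf.zeros(
          [batch_size, init_decode_length, num_heads, head_size], dtype=dtype),
  }
```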
@@ -24,8 +24,8 @@ from typing import Any, Mapping, Optional
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import optimization
-from official.vision.beta.configs import common
-from official.vision.beta.configs import image_classification as base_config
+from official.vision.configs import common
+from official.vision.configs import image_classification as base_config


 @dataclasses.dataclass
......
@@ -25,10 +25,10 @@ from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
-from official.vision.beta.configs import backbones
-from official.vision.beta.configs import common
-from official.vision.beta.configs import decoders
-from official.vision.beta.configs import semantic_segmentation as base_cfg
+from official.vision.configs import backbones
+from official.vision.configs import common
+from official.vision.configs import decoders
+from official.vision.configs import semantic_segmentation as base_cfg


 @dataclasses.dataclass
......
@@ -26,8 +26,8 @@ from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
-from official.vision.beta.configs import backbones
-from official.vision.beta.configs import semantic_segmentation as base_cfg
+from official.vision.configs import backbones
+from official.vision.configs import semantic_segmentation as base_cfg

 # ADE 20K Dataset
 ADE20K_TRAIN_EXAMPLES = 20210
......
@@ -16,8 +16,8 @@
 # Import libraries
 import tensorflow as tf

-from official.vision.beta.dataloaders import classification_input
-from official.vision.beta.ops import preprocess_ops
+from official.vision.dataloaders import classification_input
+from official.vision.ops import preprocess_ops

 MEAN_RGB = (0.5 * 255, 0.5 * 255, 0.5 * 255)
 STDDEV_RGB = (0.5 * 255, 0.5 * 255, 0.5 * 255)
......
@@ -17,8 +17,8 @@
 from absl.testing import parameterized
 import tensorflow as tf
 from official.projects.edgetpu.vision.dataloaders import classification_input
-from official.vision.beta.configs import common
-from official.vision.beta.dataloaders import tfexample_utils
+from official.vision.configs import common
+from official.vision.dataloaders import tfexample_utils

 IMAGE_FIELD_KEY = 'image/encoded'
 LABEL_FIELD_KEY = 'image/class/label'
......
@@ -22,7 +22,7 @@ import tensorflow as tf
 from official.modeling import hyperparams
 from official.projects.edgetpu.vision.modeling.mobilenet_edgetpu_v1_model import MobilenetEdgeTPU
 from official.projects.edgetpu.vision.modeling.mobilenet_edgetpu_v2_model import MobilenetEdgeTPUV2
-from official.vision.beta.modeling.backbones import factory
+from official.vision.modeling.backbones import factory

 layers = tf.keras.layers
......
@@ -31,7 +31,7 @@ from official.projects.edgetpu.vision.modeling import custom_layers
 from official.projects.edgetpu.vision.modeling.backbones import mobilenet_edgetpu
 from official.projects.edgetpu.vision.tasks import image_classification
 from official.projects.edgetpu.vision.tasks import semantic_segmentation as edgetpu_semantic_segmentation
-from official.vision.beta.tasks import semantic_segmentation
+from official.vision.tasks import semantic_segmentation
 # pylint: enable=unused-import

 MEAN_RGB = [127.5, 127.5, 127.5]
......
@@ -28,8 +28,8 @@ from official.projects.edgetpu.vision.configs import mobilenet_edgetpu_config as
 from official.projects.edgetpu.vision.dataloaders import classification_input
 from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model
 from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model
-from official.vision.beta.configs import image_classification as base_cfg
-from official.vision.beta.dataloaders import input_reader_factory
+from official.vision.configs import image_classification as base_cfg
+from official.vision.dataloaders import input_reader_factory


 def _copy_recursively(src: str, dst: str) -> None:
......
@@ -20,11 +20,11 @@ from absl.testing import parameterized
 import orbit
 import tensorflow as tf

-from official.common import registry_imports
 from official.core import exp_factory
 from official.modeling import optimization
 from official.projects.edgetpu.vision.configs import mobilenet_edgetpu_config
 from official.projects.edgetpu.vision.tasks import image_classification
+from official.vision import registry_imports

 # Dummy ImageNet TF dataset.
......
@@ -27,11 +27,11 @@ from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model
 from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model
 from official.projects.edgetpu.vision.modeling.backbones import mobilenet_edgetpu  # pylint: disable=unused-import
 from official.projects.edgetpu.vision.modeling.heads import bifpn_head
-from official.vision.beta.dataloaders import input_reader_factory
-from official.vision.beta.dataloaders import segmentation_input
-from official.vision.beta.dataloaders import tfds_factory
-from official.vision.beta.ops import preprocess_ops
-from official.vision.beta.tasks import semantic_segmentation
+from official.vision.dataloaders import input_reader_factory
+from official.vision.dataloaders import segmentation_input
+from official.vision.dataloaders import tfds_factory
+from official.vision.ops import preprocess_ops
+from official.vision.tasks import semantic_segmentation


 class ClassMappingParser(segmentation_input.Parser):
......
@@ -19,9 +19,6 @@ from absl import app
 from absl import flags
 import gin

-# pylint: disable=unused-import
-from official.common import registry_imports
-# pylint: enable=unused-import
 from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
@@ -35,6 +32,7 @@ from official.projects.edgetpu.vision.configs import semantic_segmentation_searc
 from official.projects.edgetpu.vision.modeling.backbones import mobilenet_edgetpu
 from official.projects.edgetpu.vision.tasks import image_classification
 from official.projects.edgetpu.vision.tasks import semantic_segmentation
+from official.vision import registry_imports
 # pylint: enable=unused-import

 FLAGS = flags.FLAGS
......
@@ -28,7 +28,7 @@ $ python3 train.py \
     --mode=train_and_eval
 ```

-## Model Accuracy
+## Image Classification

 <figure align="center">
   <img width=70% src=https://storage.googleapis.com/tf_model_garden/models/qat/images/readme-qat-classification-plot.png>
@@ -46,3 +46,18 @@ Note: The Top-1 model accuracy is measured on the validation set of [ImageNet](h
 |ResNet50              |224x224 |76.710% |76.420% |77.200% |[config](https://github.com/tensorflow/models/blob/master/official/projects/qat/vision/configs/experiments/image_classification/imagenet_resnet50_qat_gpu.yaml) |[TFLite(Int8/QAT)](https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_224_int8.tflite) |
 |MobileNetV3.5 MultiAVG|224x224 |75.212% |74.122% |75.130% |[config](https://github.com/tensorflow/models/blob/master/official/projects/qat/vision/configs/experiments/image_classification/imagenet_mobilenetv3.5_qat_gpu.yaml)|[TFLite(Int8/QAT)](https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v3.5multiavg_1.0_int8/mobilenet_v3.5multiavg_1.00_224_int8.tflite)|
+
+## Semantic Segmentation
+
+Models are pretrained on the COCO train set. Two datasets, the Pascal VOC
+segmentation dataset and the Cityscapes dataset (DeepLab v3+ only), are used
+to train and evaluate the models. Model accuracy is measured on the full
+Pascal VOC segmentation validation set.
+
+### Pre-trained Models
+
+model | resolution | mIoU | mIoU (FP32) | mIoU (FP16) | mIoU (INT8) | mIoU (QAT INT8) | download (tflite)
+:-------------------------- | :--------: | ----: | ----------: | ----------: | ----------: | --------------: | ----------------:
+MobileNet v2 + DeepLab v3  | 512x512   | 75.27 | 75.30 | 75.32 | 73.95 | 74.68 | [FP32](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21/model_none.tflite) \| [FP16](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21/model_fp16.tflite) \| [INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21model_int8_full.tflite) \| [QAT INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/deeplabv3_mobilenetv2_pascal_coco_0.21/Fmodel_default.tflite)
+MobileNet v2 + DeepLab v3+ | 1024x2048 | 73.82 | 73.84 | 73.65 | 72.33 | 73.49 | [FP32](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/model_none.tflite) \| [FP16](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/Fmodel_fp16.tflite) \| [INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/model_int8_full.tflite) \| [QAT INT8](https://storage.googleapis.com/tf_model_garden/vision/qat/mnv2_deeplabv3plus_cityscapes/Fmodel_default.tflite)
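The new README section only links the exported `.tflite` files. As an illustrative smoke test (not part of this commit), one of the downloaded models can be exercised with the TFLite interpreter roughly as follows; the local filename and the zero-valued input are placeholders, and the image preprocessing the model expects is omitted.

```python
# Minimal sketch: run a downloaded TFLite segmentation model once.
import numpy as np
import tensorflow as tf

# Placeholder path; point this at e.g. the FP32 DeepLab v3 file listed above.
interpreter = tf.lite.Interpreter(model_path="model_none.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Zero-valued dummy input with the model's declared shape and dtype.
image = np.zeros(input_details["shape"], dtype=input_details["dtype"])
interpreter.set_tensor(input_details["index"], image)
interpreter.invoke()
output = interpreter.get_tensor(output_details["index"])

# Assuming the model emits per-class scores, reduce to per-pixel class ids.
mask = np.argmax(output, axis=-1)
```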
@@ -21,7 +21,7 @@ from typing import Optional
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.projects.qat.vision.configs import common
-from official.vision.beta.configs import image_classification
+from official.vision.configs import image_classification


 @dataclasses.dataclass
......
@@ -22,7 +22,7 @@ from official.core import exp_factory
 from official.projects.qat.vision.configs import common
 from official.projects.qat.vision.configs import image_classification as qat_exp_cfg
 from official.vision import beta
-from official.vision.beta.configs import image_classification as exp_cfg
+from official.vision.configs import image_classification as exp_cfg


 class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
......
@@ -20,8 +20,8 @@ from typing import Optional
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.projects.qat.vision.configs import common
-from official.vision.beta.configs import retinanet
-from official.vision.beta.configs.google import backbones
+from official.vision.configs import retinanet
+from official.vision.configs.google import backbones


 @dataclasses.dataclass
......
@@ -22,7 +22,7 @@ from official.core import exp_factory
 from official.projects.qat.vision.configs import common
 from official.projects.qat.vision.configs import retinanet as qat_exp_cfg
 from official.vision import beta
-from official.vision.beta.configs import retinanet as exp_cfg
+from official.vision.configs import retinanet as exp_cfg


 class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase):
......
@@ -20,7 +20,7 @@ from typing import Optional
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.projects.qat.vision.configs import common
-from official.vision.beta.configs import semantic_segmentation
+from official.vision.configs import semantic_segmentation


 @dataclasses.dataclass
......
@@ -22,7 +22,7 @@ from official.core import exp_factory
 from official.projects.qat.vision.configs import common
 from official.projects.qat.vision.configs import semantic_segmentation as qat_exp_cfg
 from official.vision import beta
-from official.vision.beta.configs import semantic_segmentation as exp_cfg
+from official.vision.configs import semantic_segmentation as exp_cfg


 class SemanticSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
......