Commit 87495b03, authored by Liangzhe Yuan and committed by A. Unique TensorFlower
Browse files

Enable constructing a decoder-only DETRTransformer and refactor the DETR model....

Enable constructing a decoder-only DETRTransformer and refactor the DETR model. Also fix a typo in DetectionTask.

PiperOrigin-RevId: 461376875
parent 730b778e
...@@ -54,6 +54,7 @@ class Losses(hyperparams.Config): ...@@ -54,6 +54,7 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Detr(hyperparams.Config): class Detr(hyperparams.Config):
"""Detr model definitions."""
num_queries: int = 100 num_queries: int = 100
hidden_size: int = 256 hidden_size: int = 256
num_classes: int = 91 # 0: background num_classes: int = 91 # 0: background
...@@ -63,6 +64,7 @@ class Detr(hyperparams.Config): ...@@ -63,6 +64,7 @@ class Detr(hyperparams.Config):
backbone: backbones.Backbone = backbones.Backbone( backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False)) type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation() norm_activation: common.NormActivation = common.NormActivation()
backbone_endpoint_name: str = '5'
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -20,6 +20,8 @@ tf.train.Checkpoint for object based saving and loading and tf.saved_model.save ...@@ -20,6 +20,8 @@ tf.train.Checkpoint for object based saving and loading and tf.saved_model.save
for graph serialization. for graph serialization.
""" """
import math import math
from typing import Any, List
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -101,6 +103,7 @@ class DETR(tf.keras.Model): ...@@ -101,6 +103,7 @@ class DETR(tf.keras.Model):
def __init__(self, def __init__(self,
backbone, backbone,
backbone_endpoint_name,
num_queries, num_queries,
hidden_size, hidden_size,
num_classes, num_classes,
...@@ -118,10 +121,16 @@ class DETR(tf.keras.Model): ...@@ -118,10 +121,16 @@ class DETR(tf.keras.Model):
if hidden_size % 2 != 0: if hidden_size % 2 != 0:
raise ValueError("hidden_size must be a multiple of 2.") raise ValueError("hidden_size must be a multiple of 2.")
self._backbone = backbone self._backbone = backbone
self._backbone_endpoint_name = backbone_endpoint_name
def build(self, input_shape=None): def build(self, input_shape=None):
self._input_proj = tf.keras.layers.Conv2D( self._input_proj = tf.keras.layers.Conv2D(
self._hidden_size, 1, name="detr/conv2d") self._hidden_size, 1, name="detr/conv2d")
self._build_detection_decoder()
super().build(input_shape)
def _build_detection_decoder(self):
"""Builds detection decoder."""
self._transformer = DETRTransformer( self._transformer = DETRTransformer(
num_encoder_layers=self._num_encoder_layers, num_encoder_layers=self._num_encoder_layers,
num_decoder_layers=self._num_decoder_layers, num_decoder_layers=self._num_decoder_layers,
...@@ -152,7 +161,6 @@ class DETR(tf.keras.Model): ...@@ -152,7 +161,6 @@ class DETR(tf.keras.Model):
-sqrt_k, sqrt_k), -sqrt_k, sqrt_k),
name="detr/box_dense_2")] name="detr/box_dense_2")]
self._sigmoid = tf.keras.layers.Activation("sigmoid") self._sigmoid = tf.keras.layers.Activation("sigmoid")
super().build(input_shape)
@property @property
def backbone(self) -> tf.keras.Model: def backbone(self) -> tf.keras.Model:
...@@ -161,6 +169,7 @@ class DETR(tf.keras.Model): ...@@ -161,6 +169,7 @@ class DETR(tf.keras.Model):
def get_config(self): def get_config(self):
return { return {
"backbone": self._backbone, "backbone": self._backbone,
"backbone_endpoint_name": self._backbone_endpoint_name,
"num_queries": self._num_queries, "num_queries": self._num_queries,
"hidden_size": self._hidden_size, "hidden_size": self._hidden_size,
"num_classes": self._num_classes, "num_classes": self._num_classes,
...@@ -173,15 +182,21 @@ class DETR(tf.keras.Model): ...@@ -173,15 +182,21 @@ class DETR(tf.keras.Model):
def from_config(cls, config): def from_config(cls, config):
return cls(**config) return cls(**config)
def call(self, inputs): def _generate_image_mask(self, inputs: tf.Tensor,
batch_size = tf.shape(inputs)[0] target_shape: tf.Tensor) -> tf.Tensor:
"""Generates image mask from input image."""
mask = tf.expand_dims( mask = tf.expand_dims(
tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype), tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype),
axis=-1) axis=-1)
features = self._backbone(inputs)["5"]
shape = tf.shape(features)
mask = tf.image.resize( mask = tf.image.resize(
mask, shape[1:3], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return mask
def call(self, inputs: tf.Tensor) -> List[Any]:
batch_size = tf.shape(inputs)[0]
features = self._backbone(inputs)[self._backbone_endpoint_name]
shape = tf.shape(features)
mask = self._generate_image_mask(inputs, shape[1: 3])
pos_embed = position_embedding_sine( pos_embed = position_embedding_sine(
mask[:, :, :, 0], num_pos_features=self._hidden_size) mask[:, :, :, 0], num_pos_features=self._hidden_size)
...@@ -225,13 +240,16 @@ class DETRTransformer(tf.keras.layers.Layer): ...@@ -225,13 +240,16 @@ class DETRTransformer(tf.keras.layers.Layer):
self._num_decoder_layers = num_decoder_layers self._num_decoder_layers = num_decoder_layers
def build(self, input_shape=None): def build(self, input_shape=None):
if self._num_encoder_layers > 0:
self._encoder = transformer.TransformerEncoder( self._encoder = transformer.TransformerEncoder(
attention_dropout_rate=self._dropout_rate, attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate, dropout_rate=self._dropout_rate,
intermediate_dropout=self._dropout_rate, intermediate_dropout=self._dropout_rate,
norm_first=False, norm_first=False,
num_layers=self._num_encoder_layers, num_layers=self._num_encoder_layers)
) else:
self._encoder = None
self._decoder = transformer.TransformerDecoder( self._decoder = transformer.TransformerDecoder(
attention_dropout_rate=self._dropout_rate, attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate, dropout_rate=self._dropout_rate,
...@@ -255,8 +273,12 @@ class DETRTransformer(tf.keras.layers.Layer): ...@@ -255,8 +273,12 @@ class DETRTransformer(tf.keras.layers.Layer):
input_shape = tf_utils.get_shape_list(sources) input_shape = tf_utils.get_shape_list(sources)
source_attention_mask = tf.tile( source_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, input_shape[1], 1]) tf.expand_dims(mask, axis=1), [1, input_shape[1], 1])
if self._encoder is not None:
memory = self._encoder( memory = self._encoder(
sources, attention_mask=source_attention_mask, pos_embed=pos_embed) sources, attention_mask=source_attention_mask, pos_embed=pos_embed)
else:
memory = sources
target_shape = tf_utils.get_shape_list(targets) target_shape = tf_utils.get_shape_list(targets)
cross_attention_mask = tf.tile( cross_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, target_shape[1], 1]) tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
......
...@@ -27,7 +27,9 @@ class DetrTest(tf.test.TestCase): ...@@ -27,7 +27,9 @@ class DetrTest(tf.test.TestCase):
image_size = 640 image_size = 640
batch_size = 2 batch_size = 2
backbone = resnet.ResNet(50, bn_trainable=False) backbone = resnet.ResNet(50, bn_trainable=False)
model = detr.DETR(backbone, num_queries, hidden_size, num_classes) backbone_endpoint_name = '5'
model = detr.DETR(backbone, backbone_endpoint_name, num_queries,
hidden_size, num_classes)
outs = model(tf.ones((batch_size, image_size, image_size, 3))) outs = model(tf.ones((batch_size, image_size, image_size, 3)))
self.assertLen(outs, 6) # intermediate decoded outputs. self.assertLen(outs, 6) # intermediate decoded outputs.
for out in outs: for out in outs:
...@@ -50,6 +52,7 @@ class DetrTest(tf.test.TestCase): ...@@ -50,6 +52,7 @@ class DetrTest(tf.test.TestCase):
def test_get_from_config_detr(self): def test_get_from_config_detr(self):
config = { config = {
'backbone': resnet.ResNet(50, bn_trainable=False), 'backbone': resnet.ResNet(50, bn_trainable=False),
'backbone_endpoint_name': '5',
'num_queries': 2, 'num_queries': 2,
'hidden_size': 4, 'hidden_size': 4,
'num_classes': 10, 'num_classes': 10,
......
...@@ -36,7 +36,7 @@ from official.vision.ops import box_ops ...@@ -36,7 +36,7 @@ from official.vision.ops import box_ops
@task_factory.register_task_cls(detr_cfg.DetrTask) @task_factory.register_task_cls(detr_cfg.DetrTask)
class DectectionTask(base_task.Task): class DetectionTask(base_task.Task):
"""A single-replica view of training procedure. """A single-replica view of training procedure.
DETR task provides artifacts for training/evaluation procedures, including DETR task provides artifacts for training/evaluation procedures, including
...@@ -55,7 +55,9 @@ class DectectionTask(base_task.Task): ...@@ -55,7 +55,9 @@ class DectectionTask(base_task.Task):
backbone_config=self._task_config.model.backbone, backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation) norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR(backbone, self._task_config.model.num_queries, model = detr.DETR(backbone,
self._task_config.model.backbone_endpoint_name,
self._task_config.model.num_queries,
self._task_config.model.hidden_size, self._task_config.model.hidden_size,
self._task_config.model.num_classes, self._task_config.model.num_classes,
self._task_config.model.num_encoder_layers, self._task_config.model.num_encoder_layers,
......
...@@ -76,7 +76,7 @@ class DetectionTest(tf.test.TestCase): ...@@ -76,7 +76,7 @@ class DetectionTest(tf.test.TestCase):
global_batch_size=2, global_batch_size=2,
)) ))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset): with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config) task = detection.DetectionTask(config)
model = task.build_model() model = task.build_model()
dataset = task.build_inputs(config.train_data) dataset = task.build_inputs(config.train_data)
iterator = iter(dataset) iterator = iter(dataset)
...@@ -96,7 +96,7 @@ class DetectionTest(tf.test.TestCase): ...@@ -96,7 +96,7 @@ class DetectionTest(tf.test.TestCase):
} }
}, },
}) })
optimizer = detection.DectectionTask.create_optimizer(opt_cfg) optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer) task.train_step(next(iterator), model, optimizer)
def test_validation_step(self): def test_validation_step(self):
...@@ -118,7 +118,7 @@ class DetectionTest(tf.test.TestCase): ...@@ -118,7 +118,7 @@ class DetectionTest(tf.test.TestCase):
)) ))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset): with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config) task = detection.DetectionTask(config)
model = task.build_model() model = task.build_model()
metrics = task.build_metrics(training=False) metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data) dataset = task.build_inputs(config.validation_data)
...@@ -148,7 +148,7 @@ class DetectionTFDSTest(tf.test.TestCase): ...@@ -148,7 +148,7 @@ class DetectionTFDSTest(tf.test.TestCase):
global_batch_size=2, global_batch_size=2,
)) ))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset): with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config) task = detection.DetectionTask(config)
model = task.build_model() model = task.build_model()
dataset = task.build_inputs(config.train_data) dataset = task.build_inputs(config.train_data)
iterator = iter(dataset) iterator = iter(dataset)
...@@ -168,7 +168,7 @@ class DetectionTFDSTest(tf.test.TestCase): ...@@ -168,7 +168,7 @@ class DetectionTFDSTest(tf.test.TestCase):
} }
}, },
}) })
optimizer = detection.DectectionTask.create_optimizer(opt_cfg) optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer) task.train_step(next(iterator), model, optimizer)
def test_validation_step(self): def test_validation_step(self):
...@@ -190,7 +190,7 @@ class DetectionTFDSTest(tf.test.TestCase): ...@@ -190,7 +190,7 @@ class DetectionTFDSTest(tf.test.TestCase):
)) ))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset): with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config) task = detection.DetectionTask(config)
model = task.build_model() model = task.build_model()
metrics = task.build_metrics(training=False) metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data) dataset = task.build_inputs(config.validation_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment