Commit 87495b03 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Enable constructing the decoder only in DETRTransformer and refactor the DETR model.

Enable constructing the decoder only in DETRTransformer and refactor the DETR model. Also fix a typo in DetectionTask.

PiperOrigin-RevId: 461376875
parent 730b778e
......@@ -54,6 +54,7 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass
class Detr(hyperparams.Config):
"""Detr model definations."""
num_queries: int = 100
hidden_size: int = 256
num_classes: int = 91 # 0: background
......@@ -63,6 +64,7 @@ class Detr(hyperparams.Config):
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation()
backbone_endpoint_name: str = '5'
@dataclasses.dataclass
......
......@@ -20,6 +20,8 @@ tf.train.Checkpoint for object based saving and loading and tf.saved_model.save
for graph serialization.
"""
import math
from typing import Any, List
import tensorflow as tf
from official.modeling import tf_utils
......@@ -101,6 +103,7 @@ class DETR(tf.keras.Model):
def __init__(self,
backbone,
backbone_endpoint_name,
num_queries,
hidden_size,
num_classes,
......@@ -118,10 +121,16 @@ class DETR(tf.keras.Model):
if hidden_size % 2 != 0:
raise ValueError("hidden_size must be a multiple of 2.")
self._backbone = backbone
self._backbone_endpoint_name = backbone_endpoint_name
def build(self, input_shape=None):
self._input_proj = tf.keras.layers.Conv2D(
self._hidden_size, 1, name="detr/conv2d")
self._build_detection_decoder()
super().build(input_shape)
def _build_detection_decoder(self):
"""Builds detection decoder."""
self._transformer = DETRTransformer(
num_encoder_layers=self._num_encoder_layers,
num_decoder_layers=self._num_decoder_layers,
......@@ -152,7 +161,6 @@ class DETR(tf.keras.Model):
-sqrt_k, sqrt_k),
name="detr/box_dense_2")]
self._sigmoid = tf.keras.layers.Activation("sigmoid")
super().build(input_shape)
@property
def backbone(self) -> tf.keras.Model:
......@@ -161,6 +169,7 @@ class DETR(tf.keras.Model):
def get_config(self):
return {
"backbone": self._backbone,
"backbone_endpoint_name": self._backbone_endpoint_name,
"num_queries": self._num_queries,
"hidden_size": self._hidden_size,
"num_classes": self._num_classes,
......@@ -173,15 +182,21 @@ class DETR(tf.keras.Model):
def from_config(cls, config):
return cls(**config)
def call(self, inputs):
batch_size = tf.shape(inputs)[0]
def _generate_image_mask(self, inputs: tf.Tensor,
                         target_shape: tf.Tensor) -> tf.Tensor:
  """Generates a padding mask for the input image.

  A pixel is treated as valid (mask value 1) when the sum of its channel
  values is non-zero; all-zero pixels — i.e. padding — get mask value 0.
  The mask is then resized to the spatial shape of the backbone feature
  map so it can be applied to the transformer inputs.

  Args:
    inputs: Input image tensor; the channel axis is reduced, so it is
      presumably [batch, height, width, channels] — confirm with caller.
    target_shape: 1-D integer tensor [height, width] of the feature map
      the mask should be resized to.

  Returns:
    A [batch, target_height, target_width, 1] mask tensor with the same
    dtype as `inputs`.
  """
  mask = tf.expand_dims(
      tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype),
      axis=-1)
  # Nearest-neighbor resizing keeps the mask values binary (0/1).
  mask = tf.image.resize(
      mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return mask
def call(self, inputs: tf.Tensor) -> List[Any]:
batch_size = tf.shape(inputs)[0]
features = self._backbone(inputs)[self._backbone_endpoint_name]
shape = tf.shape(features)
mask = self._generate_image_mask(inputs, shape[1: 3])
pos_embed = position_embedding_sine(
mask[:, :, :, 0], num_pos_features=self._hidden_size)
......@@ -225,13 +240,16 @@ class DETRTransformer(tf.keras.layers.Layer):
self._num_decoder_layers = num_decoder_layers
def build(self, input_shape=None):
self._encoder = transformer.TransformerEncoder(
attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate,
intermediate_dropout=self._dropout_rate,
norm_first=False,
num_layers=self._num_encoder_layers,
)
if self._num_encoder_layers > 0:
self._encoder = transformer.TransformerEncoder(
attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate,
intermediate_dropout=self._dropout_rate,
norm_first=False,
num_layers=self._num_encoder_layers)
else:
self._encoder = None
self._decoder = transformer.TransformerDecoder(
attention_dropout_rate=self._dropout_rate,
dropout_rate=self._dropout_rate,
......@@ -255,8 +273,12 @@ class DETRTransformer(tf.keras.layers.Layer):
input_shape = tf_utils.get_shape_list(sources)
source_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, input_shape[1], 1])
memory = self._encoder(
sources, attention_mask=source_attention_mask, pos_embed=pos_embed)
if self._encoder is not None:
memory = self._encoder(
sources, attention_mask=source_attention_mask, pos_embed=pos_embed)
else:
memory = sources
target_shape = tf_utils.get_shape_list(targets)
cross_attention_mask = tf.tile(
tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
......
......@@ -27,7 +27,9 @@ class DetrTest(tf.test.TestCase):
image_size = 640
batch_size = 2
backbone = resnet.ResNet(50, bn_trainable=False)
model = detr.DETR(backbone, num_queries, hidden_size, num_classes)
backbone_endpoint_name = '5'
model = detr.DETR(backbone, backbone_endpoint_name, num_queries,
hidden_size, num_classes)
outs = model(tf.ones((batch_size, image_size, image_size, 3)))
self.assertLen(outs, 6) # intermediate decoded outputs.
for out in outs:
......@@ -50,6 +52,7 @@ class DetrTest(tf.test.TestCase):
def test_get_from_config_detr(self):
config = {
'backbone': resnet.ResNet(50, bn_trainable=False),
'backbone_endpoint_name': '5',
'num_queries': 2,
'hidden_size': 4,
'num_classes': 10,
......
......@@ -36,7 +36,7 @@ from official.vision.ops import box_ops
@task_factory.register_task_cls(detr_cfg.DetrTask)
class DectectionTask(base_task.Task):
class DetectionTask(base_task.Task):
"""A single-replica view of training procedure.
DETR task provides artifacts for training/evaluation procedures, including
......@@ -55,7 +55,9 @@ class DectectionTask(base_task.Task):
backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR(backbone, self._task_config.model.num_queries,
model = detr.DETR(backbone,
self._task_config.model.backbone_endpoint_name,
self._task_config.model.num_queries,
self._task_config.model.hidden_size,
self._task_config.model.num_classes,
self._task_config.model.num_encoder_layers,
......
......@@ -76,7 +76,7 @@ class DetectionTest(tf.test.TestCase):
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
task = detection.DetectionTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
......@@ -96,7 +96,7 @@ class DetectionTest(tf.test.TestCase):
}
},
})
optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
......@@ -118,7 +118,7 @@ class DetectionTest(tf.test.TestCase):
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
task = detection.DetectionTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
......@@ -148,7 +148,7 @@ class DetectionTFDSTest(tf.test.TestCase):
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
task = detection.DetectionTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
......@@ -168,7 +168,7 @@ class DetectionTFDSTest(tf.test.TestCase):
}
},
})
optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
......@@ -190,7 +190,7 @@ class DetectionTFDSTest(tf.test.TestCase):
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
task = detection.DetectionTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment