Commit f7783e7a authored by Gunho Park's avatar Gunho Park
Browse files

Use backbone factory

parent 14a9701d
...@@ -62,6 +62,7 @@ class Losses(hyperparams.Config): ...@@ -62,6 +62,7 @@ class Losses(hyperparams.Config):
lambda_box: float = 5.0 lambda_box: float = 5.0
lambda_giou: float = 2.0 lambda_giou: float = 2.0
background_cls_weight: float = 0.1 background_cls_weight: float = 0.1
l2_weight_decay: float = 1e-4
@dataclasses.dataclass @dataclasses.dataclass
class Detr(hyperparams.Config): class Detr(hyperparams.Config):
...@@ -73,7 +74,7 @@ class Detr(hyperparams.Config): ...@@ -73,7 +74,7 @@ class Detr(hyperparams.Config):
input_size: List[int] = dataclasses.field(default_factory=list) input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone( backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet( type='resnet', resnet=backbones.ResNet(
model_id=101, model_id=50,
bn_trainable=False)) bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation() norm_activation: common.NormActivation = common.NormActivation()
...@@ -105,7 +106,7 @@ def detr_coco() -> cfg.ExperimentConfig: ...@@ -105,7 +106,7 @@ def detr_coco() -> cfg.ExperimentConfig:
decay_at = train_steps - 100 * steps_per_epoch # 400 epochs decay_at = train_steps - 100 * steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
task=DetrTask( task=DetrTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet101_imagenet', init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint_modules='backbone', init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE, annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'), 'instances_val2017.json'),
......
...@@ -24,7 +24,7 @@ import tensorflow as tf ...@@ -24,7 +24,7 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.projects.detr.modeling import transformer from official.projects.detr.modeling import transformer
#from official.vision.modeling.backbones import resnet from official.vision.modeling.backbones import resnet
def position_embedding_sine(attention_mask, def position_embedding_sine(attention_mask,
...@@ -116,7 +116,7 @@ class DETR(tf.keras.Model): ...@@ -116,7 +116,7 @@ class DETR(tf.keras.Model):
raise ValueError("hidden_size must be a multiple of 2.") raise ValueError("hidden_size must be a multiple of 2.")
# TODO(frederickliu): Consider using the backbone factory. # TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in. # TODO(frederickliu): Add to factory once we get skeleton code in.
#self._backbone = resnet.ResNet(50, bn_trainable=False) #self._backbone = resnet.ResNet(101, bn_trainable=False)
# (gunho) use backbone factory # (gunho) use backbone factory
self._backbone = backbone self._backbone = backbone
......
...@@ -48,12 +48,17 @@ class DectectionTask(base_task.Task): ...@@ -48,12 +48,17 @@ class DectectionTask(base_task.Task):
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._task_config.model.input_size) shape=[None] + self._task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
backbone_config=self._task_config.model.backbone, backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation) norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR( model = detr.DETR(
backbone, backbone,
self._task_config.model.num_queries, self._task_config.model.num_queries,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment