Commit 67446adb authored by Gunho Park's avatar Gunho Park
Browse files

Make both input types work

parent 38a5d626
......@@ -58,6 +58,7 @@ class DataConfig(cfg.DataConfig):
@dataclasses.dataclass
class Losses(hyperparams.Config):
class_offset: int = 0
lambda_cls: float = 1.0
lambda_box: float = 5.0
lambda_giou: float = 2.0
......@@ -101,7 +102,7 @@ def detr_coco() -> cfg.ExperimentConfig:
decay_at = train_steps - 100 * num_steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone',
model=Detr(
num_classes=81,
......@@ -109,7 +110,6 @@ def detr_coco() -> cfg.ExperimentConfig:
norm_activation=common.NormActivation(use_sync_bn=False)),
losses=Losses(),
train_data=coco.COCODataConfig(
file_type='tfrecord',
tfds_name='coco/2017',
tfds_split='train',
is_training=True,
......@@ -117,7 +117,6 @@ def detr_coco() -> cfg.ExperimentConfig:
shuffle_buffer_size=1000,
),
validation_data=coco.COCODataConfig(
file_type='tfrecord',
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
......@@ -159,7 +158,7 @@ def detr_coco() -> cfg.ExperimentConfig:
])
return config
COCO_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/coco'
COCO_INPUT_PATH_BASE = ''
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000
......@@ -173,7 +172,7 @@ def detr_coco() -> cfg.ExperimentConfig:
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
......@@ -227,3 +226,71 @@ def detr_coco() -> cfg.ExperimentConfig:
'task.train_data.is_training != None',
])
return config
@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone',
model=Detr(
num_classes=81,
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation(use_sync_bn=False)),
losses=Losses(
class_offset=1
),
train_data=DataConfig(
tfds_name='coco/2017',
tfds_split='train',
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False
)
),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
validation_interval=5*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [decay_at],
'values': [0.0001, 1.0e-05]
}
},
})
),
restrictions=[
'task.train_data.is_training != None',
])
return config
\ No newline at end of file
......@@ -27,7 +27,6 @@ from official.vision.ops import preprocess_ops
@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
"""Data config for COCO."""
file_type: str = 'tfrecord'
output_size: Tuple[int, int] = (1333, 1333)
max_num_boxes: int = 100
resize_scales: Tuple[int, ...] = (
......
......@@ -31,34 +31,28 @@ class Parser(parser.Parser):
"""Parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
class_offset: int = 0,
output_size: Tuple[int, int] = (1333, 1333),
max_num_boxes: int = 100,
resize_scales: Tuple[int, ...] = RESIZE_SCALES,
aug_rand_hflip=True):
self._class_offset = class_offset
self._output_size = output_size
self._max_num_boxes = max_num_boxes
self._resize_scales = resize_scales
self._aug_rand_hflip = aug_rand_hflip
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
#classes = data['groundtruth_classes'] + 1
classes = data['groundtruth_classes']
classes = data['groundtruth_classes'] + self._class_offset
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image.
image = data['image']
# Apply autoaug or randaug.
#if self._augmenter is not None:
# image, boxes = self._augmenter.distort_with_boxes(image, boxes)
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
do_crop = tf.greater(tf.random.uniform([]), 0.5)
......
......@@ -31,6 +31,7 @@ from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.projects.detr.dataloaders import detr_input
from official.projects.detr.dataloaders import coco
from official.vision.modeling import backbones
@task_factory.register_task_cls(detr_cfg.DetrTask)
......@@ -84,16 +85,13 @@ class DectectionTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
"""def build_inputs(self,
params: detr_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
return coco.COCODataLoader(params).load(input_context)"""
def build_inputs(self,
params,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
if type(params) is coco.COCODataConfig:
dataset = coco.COCODataLoader(params).load(input_context)
else:
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
......@@ -110,6 +108,7 @@ class DectectionTask(base_task.Task):
params.decoder.type))
parser = detr_input.Parser(
class_offset=self._task_config.losses.class_offset,
output_size=self._task_config.model.input_size[:2],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment