Commit a0876f62 authored by Gunho Park's avatar Gunho Park
Browse files

Internal change

parent 2412b118
......@@ -92,7 +92,6 @@ class DetrTask(cfg.TaskConfig):
COCO_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/coco'
COCO_TRAIN_EXAMPLES = 118287
#COCO_TRAIN_EXAMPLES = 9600
COCO_VAL_EXAMPLES = 5000
@exp_factory.register_config_factory('detr_coco')
......@@ -100,10 +99,9 @@ def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 32
eval_batch_size = 64
num_train_data = 118287
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 500 epochs
decay_at = train_steps - 100 * steps_per_epoch # 400 epochs
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet',
......
#!/bin/bash
python3 train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=gs://ghpark-ckpts/detr/detr_coco/ckpt_03_detr_coco_resnet101 \
--tpu=postech-tpu \
--params_override=runtime.distribution_strategy='tpu'
\ No newline at end of file
......@@ -48,13 +48,6 @@ class DectectionTask(base_task.Task):
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._task_config.model.input_size)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=self._task_config.model.backbone,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment