You need to sign in or sign up before continuing.
Commit 54b352d9 authored by Albert Pumarola's avatar Albert Pumarola Committed by Facebook GitHub Bot
Browse files

Add unittest for DETR runner

Summary: Add create and train unit tests to OSS runner

Reviewed By: zhanghang1989

Differential Revision: D29254417

fbshipit-source-id: f7c52b90b2bc7afa83a204895be149664c675e52
parent 58f0ae3d
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import tempfile
import unittest
from detr import runner as oss_runner
import d2go.runner.default_runner as default_runner
from d2go.utils.testing.data_loader_helper import create_local_dataset
# RUN:
# buck test mobile-vision/d2go/projects_oss/detr:test_detr_runner
def _get_cfg(runner, output_dir, dataset_name):
cfg = runner.get_default_cfg()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "Detr"
cfg.DATASETS.TRAIN = (dataset_name,)
cfg.DATASETS.TEST = (dataset_name,)
cfg.INPUT.MIN_SIZE_TRAIN = (10,)
cfg.INPUT.MIN_SIZE_TEST = (10,)
cfg.SOLVER.MAX_ITER = 5
cfg.SOLVER.STEPS = []
cfg.SOLVER.WARMUP_ITERS = 1
cfg.SOLVER.CHECKPOINT_PERIOD = 1
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.OUTPUT_DIR = output_dir
return cfg
class TestOssRunner(unittest.TestCase):
def test_build_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
runner = oss_runner.DETRRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name)
model = runner.build_model(cfg)
dl = runner.build_detection_train_loader(cfg)
batch = next(iter(dl))
output = model(batch)
self.assertIsInstance(output, dict)
model.eval()
output = model(batch)
self.assertIsInstance(output, list)
default_runner._close_all_tbx_writers()
def test_runner_train(self):
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10, num_classes=1000)
runner = oss_runner.DETRRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name)
model = runner.build_model(cfg)
runner.do_train(cfg, model, True)
final_model_path = os.path.join(tmp_dir, "model_final.pth")
self.assertTrue(os.path.isfile(final_model_path))
default_runner._close_all_tbx_writers()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment