"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ca1e40726e0e6b770e62ba98e4c6f74ea6bf9647"
Commit 54b352d9 authored by Albert Pumarola's avatar Albert Pumarola Committed by Facebook GitHub Bot
Browse files

Add unittest for DETR runner

Summary: Add create and train unit tests to OSS runner

Reviewed By: zhanghang1989

Differential Revision: D29254417

fbshipit-source-id: f7c52b90b2bc7afa83a204895be149664c675e52
parent 58f0ae3d
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import tempfile
import unittest
from detr import runner as oss_runner
import d2go.runner.default_runner as default_runner
from d2go.utils.testing.data_loader_helper import create_local_dataset
# RUN:
# buck test mobile-vision/d2go/projects_oss/detr:test_detr_runner
def _get_cfg(runner, output_dir, dataset_name):
cfg = runner.get_default_cfg()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "Detr"
cfg.DATASETS.TRAIN = (dataset_name,)
cfg.DATASETS.TEST = (dataset_name,)
cfg.INPUT.MIN_SIZE_TRAIN = (10,)
cfg.INPUT.MIN_SIZE_TEST = (10,)
cfg.SOLVER.MAX_ITER = 5
cfg.SOLVER.STEPS = []
cfg.SOLVER.WARMUP_ITERS = 1
cfg.SOLVER.CHECKPOINT_PERIOD = 1
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.OUTPUT_DIR = output_dir
return cfg
class TestOssRunner(unittest.TestCase):
def test_build_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
runner = oss_runner.DETRRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name)
model = runner.build_model(cfg)
dl = runner.build_detection_train_loader(cfg)
batch = next(iter(dl))
output = model(batch)
self.assertIsInstance(output, dict)
model.eval()
output = model(batch)
self.assertIsInstance(output, list)
default_runner._close_all_tbx_writers()
def test_runner_train(self):
with tempfile.TemporaryDirectory() as tmp_dir:
ds_name = create_local_dataset(tmp_dir, 5, 10, 10, num_classes=1000)
runner = oss_runner.DETRRunner()
cfg = _get_cfg(runner, tmp_dir, ds_name)
model = runner.build_model(cfg)
runner.do_train(cfg, model, True)
final_model_path = os.path.join(tmp_dir, "model_final.pth")
self.assertTrue(os.path.isfile(final_model_path))
default_runner._close_all_tbx_writers()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment