Commit 9d238344 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

move test utils to core library

Summary: Not d2go.tests is not a library for oss, move utils code to d2go.utils.testing

Reviewed By: zhanghang1989

Differential Revision: D26706933

fbshipit-source-id: 85767b66bbb6c67db05e11823beb4840220b2aa3
parent ec2e8fff
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
......@@ -140,7 +140,7 @@ def create_local_dataset(
"meta_data": meta_data,
}
if is_rotated:
split_dict['evaluator_type'] = "rotated_coco"
split_dict["evaluator_type"] = "rotated_coco"
register_dataset_split(dataset_name, split_dict)
return dataset_name
......
......@@ -2,7 +2,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from functools import wraps
from tempfile import TemporaryDirectory
......@@ -42,6 +41,7 @@ def enable_ddp_env(func):
return wrapper
def tempdir(func):
""" A decorator for creating a tempory directory that is cleaned up after function execution. """
......
......@@ -2,8 +2,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from functools import wraps
from tempfile import TemporaryDirectory
from typing import Optional
import torch
......
......@@ -3,11 +3,10 @@
import torch
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
from d2go.tests.data_loader_helper import create_local_dataset
@META_ARCH_REGISTRY.register()
class DetMetaArchForTest(torch.nn.Module):
......
......@@ -93,4 +93,3 @@ if __name__ == '__main__':
]
},
)
......@@ -9,7 +9,10 @@ import unittest
import d2go.data.extended_coco as extended_coco
from d2go.data.utils import maybe_subsample_n_images
from d2go.runner import Detectron2GoRunner
from d2go.tests.data_loader_helper import LocalImageGenerator, create_toy_dataset
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
create_toy_dataset,
)
from detectron2.data import DatasetCatalog
from mobile_cv.common.misc.file_utils import make_temp_directory
......
......@@ -5,11 +5,12 @@
import os
import unittest
from d2go.runner import GeneralizedRCNNRunner, create_runner
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
register_toy_dataset,
)
from mobile_cv.common.misc.file_utils import make_temp_directory
from PIL import Image
from d2go.tests.data_loader_helper import LocalImageGenerator, register_toy_dataset
class TestD2GoDatasetMapper(unittest.TestCase):
......
......@@ -6,20 +6,22 @@ import unittest
import numpy as np
from d2go.data.transforms import color_yuv as cy
from d2go.data.transforms.build import build_transform_gen
from d2go.runner import Detectron2GoRunner
from detectron2.data.transforms import apply_augmentations
from d2go.data.transforms.build import build_transform_gen
class TestDataTransformsColorYUV(unittest.TestCase):
def test_yuv_color_transforms(self):
default_cfg = Detectron2GoRunner().get_default_cfg()
img = np.concatenate([
img = np.concatenate(
[
np.random.uniform(0, 1, size=(80, 60, 1)),
np.random.uniform(-0.5, 0.5, size=(80, 60, 1)),
np.random.uniform(-0.5, 0.5, size=(80, 60, 1)),
], axis=2)
],
axis=2,
)
default_cfg.D2GO_DATA.AUG_OPS.TRAIN = [
'RandomContrastYUVOp::{"intensity_min": 0.3, "intensity_max": 0.5}',
......@@ -45,7 +47,6 @@ class TestDataTransformsColorYUV(unittest.TestCase):
self.assertGreater(np.var(high_saturation[:, :, 1]), np.var(img[:, :, 1]))
self.assertGreater(np.var(high_saturation[:, :, 2]), np.var(img[:, :, 2]))
def test_transform_color_yuv_rgbyuv_convert(self):
image = np.arange(256).reshape(16, 16, 1).repeat(3, axis=2).astype(np.uint8)
tf1 = cy.RGB2YUVBT601().get_transform(image)
......
......@@ -7,19 +7,21 @@ import unittest
import numpy as np
import torch
from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset
from detectron2.engine import SimpleTrainer
from d2go.modeling.kmeans_anchors import (
add_kmeans_anchors_cfg,
compute_kmeans_anchors,
compute_kmeans_anchors_hook,
)
from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
register_toy_dataset,
)
from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset
from detectron2.engine import SimpleTrainer
from mobile_cv.common.misc.file_utils import make_temp_directory
from torch.utils.data.sampler import BatchSampler, Sampler
from d2go.tests.data_loader_helper import LocalImageGenerator, register_toy_dataset
class IntervalSampler(Sampler):
def __init__(self, size: int, interval: int):
......
......@@ -9,9 +9,9 @@ import numpy as np
from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tests import meta_arch_helper as mah
from d2go.tests.helper import tempdir
from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
class TestLightningTrainNet(unittest.TestCase):
......
......@@ -8,10 +8,9 @@ import d2go.data.transforms.box_utils as bu
import d2go.modeling.image_pooler as image_pooler
import numpy as np
import torch
from d2go.utils.testing import rcnn_helper as rh
from detectron2.structures import Boxes
from d2go.tests import rcnn_helper as rh
class TestModelingImagePooler(unittest.TestCase):
def test_image_pooler(self):
......
......@@ -6,11 +6,10 @@ import copy
import itertools
import unittest
import d2go.runner.default_runner as default_runner
import torch
from d2go.modeling import model_ema
import d2go.runner.default_runner as default_runner
from d2go.tests import helper
from d2go.utils.testing import helper
class TestArch(torch.nn.Module):
......@@ -174,6 +173,4 @@ class TestModelingModelEMAHook(unittest.TestCase):
out_model = TestArch()
ema_checkpointers["ema_state"].apply_to(out_model)
self.assertTrue(
_compare_state_dict(out_model, model)
)
self.assertTrue(_compare_state_dict(out_model, model))
......@@ -6,10 +6,9 @@ import unittest
import numpy as np
import torch
from d2go.utils.testing import rcnn_helper as rh
from detectron2.structures import Boxes
from d2go.tests import rcnn_helper as rh
class TestRCNNHelper(unittest.TestCase):
def test_get_instances_from_image(self):
......
......@@ -9,6 +9,8 @@ import unittest
import d2go.runner.default_runner as default_runner
import torch
from d2go.utils.testing import helper
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances
......@@ -17,9 +19,6 @@ from mobile_cv.arch.quantization.qconfig import (
)
from torch.nn.parallel import DistributedDataParallel
from d2go.tests import helper
from d2go.tests.data_loader_helper import create_local_dataset
@META_ARCH_REGISTRY.register()
class MetaArchForTest(torch.nn.Module):
......@@ -328,6 +327,7 @@ class TestDefaultRunner(unittest.TestCase):
default_runner._close_all_tbx_writers()
def _compare_state_dict(sd1, sd2, abs_error=1e-3):
if len(sd1) != len(sd2):
return False
......
......@@ -13,9 +13,9 @@ from d2go.runner.callbacks.quantization import (
get_default_qconfig,
get_default_qat_qconfig,
)
from d2go.tests.helper import tempdir
from d2go.tests.lightning_test_module import TestModule
from d2go.utils.misc import mode
from d2go.utils.testing.helper import tempdir
from d2go.utils.testing.lightning_test_module import TestModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
......
......@@ -12,7 +12,7 @@ import pytorch_lightning as pl # type: ignore
import torch
from d2go.config import CfgNode
from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tests import meta_arch_helper as mah
from d2go.utils.testing import meta_arch_helper as mah
from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment