Commit 9d238344 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

move test utils to core library

Summary: Not d2go.tests is not a library for oss, move utils code to d2go.utils.testing

Reviewed By: zhanghang1989

Differential Revision: D26706933

fbshipit-source-id: 85767b66bbb6c67db05e11823beb4840220b2aa3
parent ec2e8fff
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
...@@ -44,7 +44,7 @@ def create_toy_dataset( ...@@ -44,7 +44,7 @@ def create_toy_dataset(
bbox = ( bbox = (
[width / 4, height / 4, width / 2, height / 2] # XYWH_ABS [width / 4, height / 4, width / 2, height / 2] # XYWH_ABS
if not is_rotated if not is_rotated
else [width / 2, height / 2, width / 2, height / 2, 45] # cXcYWHO_ABS else [width / 2, height / 2, width / 2, height / 2, 45] # cXcYWHO_ABS
) )
annotations.append( annotations.append(
...@@ -140,7 +140,7 @@ def create_local_dataset( ...@@ -140,7 +140,7 @@ def create_local_dataset(
"meta_data": meta_data, "meta_data": meta_data,
} }
if is_rotated: if is_rotated:
split_dict['evaluator_type'] = "rotated_coco" split_dict["evaluator_type"] = "rotated_coco"
register_dataset_split(dataset_name, split_dict) register_dataset_split(dataset_name, split_dict)
return dataset_name return dataset_name
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from functools import wraps from functools import wraps
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
...@@ -42,6 +41,7 @@ def enable_ddp_env(func): ...@@ -42,6 +41,7 @@ def enable_ddp_env(func):
return wrapper return wrapper
def tempdir(func): def tempdir(func):
""" A decorator for creating a tempory directory that is cleaned up after function execution. """ """ A decorator for creating a tempory directory that is cleaned up after function execution. """
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe # pyre-unsafe
from functools import wraps
from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional
import torch import torch
......
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
import torch import torch
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances from detectron2.structures import Boxes, ImageList, Instances
from d2go.tests.data_loader_helper import create_local_dataset
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
class DetMetaArchForTest(torch.nn.Module): class DetMetaArchForTest(torch.nn.Module):
......
...@@ -93,4 +93,3 @@ if __name__ == '__main__': ...@@ -93,4 +93,3 @@ if __name__ == '__main__':
] ]
}, },
) )
...@@ -9,7 +9,10 @@ import unittest ...@@ -9,7 +9,10 @@ import unittest
import d2go.data.extended_coco as extended_coco import d2go.data.extended_coco as extended_coco
from d2go.data.utils import maybe_subsample_n_images from d2go.data.utils import maybe_subsample_n_images
from d2go.runner import Detectron2GoRunner from d2go.runner import Detectron2GoRunner
from d2go.tests.data_loader_helper import LocalImageGenerator, create_toy_dataset from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
create_toy_dataset,
)
from detectron2.data import DatasetCatalog from detectron2.data import DatasetCatalog
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
......
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
import os import os
import unittest import unittest
from d2go.runner import GeneralizedRCNNRunner, create_runner from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
register_toy_dataset,
)
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
from PIL import Image
from d2go.tests.data_loader_helper import LocalImageGenerator, register_toy_dataset
class TestD2GoDatasetMapper(unittest.TestCase): class TestD2GoDatasetMapper(unittest.TestCase):
......
...@@ -6,20 +6,22 @@ import unittest ...@@ -6,20 +6,22 @@ import unittest
import numpy as np import numpy as np
from d2go.data.transforms import color_yuv as cy from d2go.data.transforms import color_yuv as cy
from d2go.data.transforms.build import build_transform_gen
from d2go.runner import Detectron2GoRunner from d2go.runner import Detectron2GoRunner
from detectron2.data.transforms import apply_augmentations from detectron2.data.transforms import apply_augmentations
from d2go.data.transforms.build import build_transform_gen
class TestDataTransformsColorYUV(unittest.TestCase): class TestDataTransformsColorYUV(unittest.TestCase):
def test_yuv_color_transforms(self): def test_yuv_color_transforms(self):
default_cfg = Detectron2GoRunner().get_default_cfg() default_cfg = Detectron2GoRunner().get_default_cfg()
img = np.concatenate([ img = np.concatenate(
np.random.uniform(0, 1, size=(80, 60, 1)), [
np.random.uniform(-0.5, 0.5, size=(80, 60, 1)), np.random.uniform(0, 1, size=(80, 60, 1)),
np.random.uniform(-0.5, 0.5, size=(80, 60, 1)), np.random.uniform(-0.5, 0.5, size=(80, 60, 1)),
], axis=2) np.random.uniform(-0.5, 0.5, size=(80, 60, 1)),
],
axis=2,
)
default_cfg.D2GO_DATA.AUG_OPS.TRAIN = [ default_cfg.D2GO_DATA.AUG_OPS.TRAIN = [
'RandomContrastYUVOp::{"intensity_min": 0.3, "intensity_max": 0.5}', 'RandomContrastYUVOp::{"intensity_min": 0.3, "intensity_max": 0.5}',
...@@ -45,7 +47,6 @@ class TestDataTransformsColorYUV(unittest.TestCase): ...@@ -45,7 +47,6 @@ class TestDataTransformsColorYUV(unittest.TestCase):
self.assertGreater(np.var(high_saturation[:, :, 1]), np.var(img[:, :, 1])) self.assertGreater(np.var(high_saturation[:, :, 1]), np.var(img[:, :, 1]))
self.assertGreater(np.var(high_saturation[:, :, 2]), np.var(img[:, :, 2])) self.assertGreater(np.var(high_saturation[:, :, 2]), np.var(img[:, :, 2]))
def test_transform_color_yuv_rgbyuv_convert(self): def test_transform_color_yuv_rgbyuv_convert(self):
image = np.arange(256).reshape(16, 16, 1).repeat(3, axis=2).astype(np.uint8) image = np.arange(256).reshape(16, 16, 1).repeat(3, axis=2).astype(np.uint8)
tf1 = cy.RGB2YUVBT601().get_transform(image) tf1 = cy.RGB2YUVBT601().get_transform(image)
......
...@@ -7,19 +7,21 @@ import unittest ...@@ -7,19 +7,21 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset
from detectron2.engine import SimpleTrainer
from d2go.modeling.kmeans_anchors import ( from d2go.modeling.kmeans_anchors import (
add_kmeans_anchors_cfg, add_kmeans_anchors_cfg,
compute_kmeans_anchors, compute_kmeans_anchors,
compute_kmeans_anchors_hook, compute_kmeans_anchors_hook,
) )
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
register_toy_dataset,
)
from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset
from detectron2.engine import SimpleTrainer
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
from torch.utils.data.sampler import BatchSampler, Sampler from torch.utils.data.sampler import BatchSampler, Sampler
from d2go.tests.data_loader_helper import LocalImageGenerator, register_toy_dataset
class IntervalSampler(Sampler): class IntervalSampler(Sampler):
def __init__(self, size: int, interval: int): def __init__(self, size: int, interval: int):
......
...@@ -9,9 +9,9 @@ import numpy as np ...@@ -9,9 +9,9 @@ import numpy as np
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict from d2go.config.utils import flatten_config_dict
from d2go.runner.lightning_task import GeneralizedRCNNTask from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tests import meta_arch_helper as mah
from d2go.tests.helper import tempdir
from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT from d2go.tools.lightning_train_net import main, FINAL_MODEL_CKPT
from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir
class TestLightningTrainNet(unittest.TestCase): class TestLightningTrainNet(unittest.TestCase):
......
...@@ -8,10 +8,9 @@ import d2go.data.transforms.box_utils as bu ...@@ -8,10 +8,9 @@ import d2go.data.transforms.box_utils as bu
import d2go.modeling.image_pooler as image_pooler import d2go.modeling.image_pooler as image_pooler
import numpy as np import numpy as np
import torch import torch
from d2go.utils.testing import rcnn_helper as rh
from detectron2.structures import Boxes from detectron2.structures import Boxes
from d2go.tests import rcnn_helper as rh
class TestModelingImagePooler(unittest.TestCase): class TestModelingImagePooler(unittest.TestCase):
def test_image_pooler(self): def test_image_pooler(self):
......
...@@ -6,11 +6,10 @@ import copy ...@@ -6,11 +6,10 @@ import copy
import itertools import itertools
import unittest import unittest
import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.modeling import model_ema from d2go.modeling import model_ema
import d2go.runner.default_runner as default_runner from d2go.utils.testing import helper
from d2go.tests import helper
class TestArch(torch.nn.Module): class TestArch(torch.nn.Module):
...@@ -174,6 +173,4 @@ class TestModelingModelEMAHook(unittest.TestCase): ...@@ -174,6 +173,4 @@ class TestModelingModelEMAHook(unittest.TestCase):
out_model = TestArch() out_model = TestArch()
ema_checkpointers["ema_state"].apply_to(out_model) ema_checkpointers["ema_state"].apply_to(out_model)
self.assertTrue( self.assertTrue(_compare_state_dict(out_model, model))
_compare_state_dict(out_model, model)
)
...@@ -6,10 +6,9 @@ import unittest ...@@ -6,10 +6,9 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from d2go.utils.testing import rcnn_helper as rh
from detectron2.structures import Boxes from detectron2.structures import Boxes
from d2go.tests import rcnn_helper as rh
class TestRCNNHelper(unittest.TestCase): class TestRCNNHelper(unittest.TestCase):
def test_get_instances_from_image(self): def test_get_instances_from_image(self):
......
...@@ -9,6 +9,8 @@ import unittest ...@@ -9,6 +9,8 @@ import unittest
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.utils.testing import helper
from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator
from detectron2.modeling import META_ARCH_REGISTRY from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances from detectron2.structures import Boxes, ImageList, Instances
...@@ -17,9 +19,6 @@ from mobile_cv.arch.quantization.qconfig import ( ...@@ -17,9 +19,6 @@ from mobile_cv.arch.quantization.qconfig import (
) )
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from d2go.tests import helper
from d2go.tests.data_loader_helper import create_local_dataset
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
class MetaArchForTest(torch.nn.Module): class MetaArchForTest(torch.nn.Module):
...@@ -328,6 +327,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -328,6 +327,7 @@ class TestDefaultRunner(unittest.TestCase):
default_runner._close_all_tbx_writers() default_runner._close_all_tbx_writers()
def _compare_state_dict(sd1, sd2, abs_error=1e-3): def _compare_state_dict(sd1, sd2, abs_error=1e-3):
if len(sd1) != len(sd2): if len(sd1) != len(sd2):
return False return False
......
...@@ -13,9 +13,9 @@ from d2go.runner.callbacks.quantization import ( ...@@ -13,9 +13,9 @@ from d2go.runner.callbacks.quantization import (
get_default_qconfig, get_default_qconfig,
get_default_qat_qconfig, get_default_qat_qconfig,
) )
from d2go.tests.helper import tempdir
from d2go.tests.lightning_test_module import TestModule
from d2go.utils.misc import mode from d2go.utils.misc import mode
from d2go.utils.testing.helper import tempdir
from d2go.utils.testing.lightning_test_module import TestModule
from pytorch_lightning import Trainer, seed_everything from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
......
...@@ -12,7 +12,7 @@ import pytorch_lightning as pl # type: ignore ...@@ -12,7 +12,7 @@ import pytorch_lightning as pl # type: ignore
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.runner.lightning_task import GeneralizedRCNNTask from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.tests import meta_arch_helper as mah from d2go.utils.testing import meta_arch_helper as mah
from detectron2.utils.events import EventStorage from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor from torch import Tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment