Commit a0658c4a authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

reorganize unit tests

Summary: Separate unit tests into individual folder based on functionality.

Reviewed By: wat3rBro

Differential Revision: D27132567

fbshipit-source-id: 9a8200be530ca14c7ef42191d59795b05b9800cc
parent d29f93e7
...@@ -2,16 +2,29 @@ ...@@ -2,16 +2,29 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib
import os import os
import socket import socket
import uuid import uuid
from functools import wraps from functools import wraps
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
def get_resource_path(file: Optional[str] = None):
path_list = [
os.path.dirname(importlib.import_module("d2go.tests").__file__),
"resources",
]
if file is not None:
path_list.append(file)
return os.path.join(*path_list)
def skip_if_no_gpu(func): def skip_if_no_gpu(func):
"""Decorator that can be used to skip GPU tests on non-GPU machines.""" """Decorator that can be used to skip GPU tests on non-GPU machines."""
func.skip_if_no_gpu = True func.skip_if_no_gpu = True
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
import unittest import unittest
import numpy as np import numpy as np
from detectron2.data.transforms import apply_transform_gens
from d2go.data.transforms.build import build_transform_gen from d2go.data.transforms.build import build_transform_gen
from d2go.runner import Detectron2GoRunner from d2go.runner import Detectron2GoRunner
from detectron2.data.transforms import apply_transform_gens
class TestDataTransforms(unittest.TestCase): class TestDataTransforms(unittest.TestCase):
......
...@@ -10,7 +10,7 @@ from d2go.data.transforms import tensor as tensor_aug ...@@ -10,7 +10,7 @@ from d2go.data.transforms import tensor as tensor_aug
from detectron2.data.transforms.augmentation import AugmentationList from detectron2.data.transforms.augmentation import AugmentationList
class TestDataTransformsTenaor(unittest.TestCase): class TestDataTransformsTensor(unittest.TestCase):
def test_tensor_aug(self): def test_tensor_aug(self):
"""Data augmentation that that allows torch.Tensor as input""" """Data augmentation that that allows torch.Tensor as input"""
......
...@@ -9,6 +9,7 @@ import unittest ...@@ -9,6 +9,7 @@ import unittest
from d2go.config import auto_scale_world_size, reroute_config_path from d2go.config import auto_scale_world_size, reroute_config_path
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.tests.helper import get_resource_path
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
...@@ -31,8 +32,7 @@ class TestConfigs(unittest.TestCase): ...@@ -31,8 +32,7 @@ class TestConfigs(unittest.TestCase):
""" Test arch def str-to-dict conversion compatible with merging """ """ Test arch def str-to-dict conversion compatible with merging """
default_cfg = GeneralizedRCNNRunner().get_default_cfg() default_cfg = GeneralizedRCNNRunner().get_default_cfg()
cfg = default_cfg.clone() cfg = default_cfg.clone()
cfg.merge_from_file(os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg.merge_from_file(get_resource_path("arch_def_merging.yaml"))
"resources/arch_def_merging.yaml"))
with make_temp_directory("detectron2go_tmp") as tmp_dir: with make_temp_directory("detectron2go_tmp") as tmp_dir:
# Dump out config with arch def # Dump out config with arch def
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os import os
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -50,10 +49,9 @@ class TestLightningTrainNet(unittest.TestCase): ...@@ -50,10 +49,9 @@ class TestLightningTrainNet(unittest.TestCase):
ckpts, ckpts,
) )
tmp_dir2 = tempfile.TemporaryDirectory() # noqa to avoid flaky test
cfg2 = cfg.clone() cfg2 = cfg.clone()
cfg2.defrost() cfg2.defrost()
cfg2.OUTPUT_DIR = tmp_dir2.name cfg2.OUTPUT_DIR = os.path.join(tmp_dir, 'output')
# load the last checkpoint from previous training # load the last checkpoint from previous training
cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt") cfg2.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
...@@ -62,4 +60,3 @@ class TestLightningTrainNet(unittest.TestCase): ...@@ -62,4 +60,3 @@ class TestLightningTrainNet(unittest.TestCase):
accuracy2 = flatten_config_dict(out2.accuracy) accuracy2 = flatten_config_dict(out2.accuracy)
for k in accuracy: for k in accuracy:
np.testing.assert_equal(accuracy[k], accuracy2[k]) np.testing.assert_equal(accuracy[k], accuracy2[k])
tmp_dir2.cleanup()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment