"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1d4ad34af04abc2fde96ed1e1ae7995173681bbc"
Commit dcb9cb48 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

move oss utils from d2go to mobile_cv

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/261

X-link: https://github.com/facebookresearch/mobile-vision/pull/71

`is_oss` and `fb_overwritable` are also needed in `mobile_cv`, move them from d2go.

Reviewed By: zhanghang1989

Differential Revision: D36655821

fbshipit-source-id: 421c4d22d4c4620678908fe13d6e47ab39604ae7
parent 6f02a8de
...@@ -35,7 +35,5 @@ jobs: ...@@ -35,7 +35,5 @@ jobs:
pip install -e . pip install -e .
- name: Run pytest - name: Run pytest
env:
OSSRUN: 1
run: | run: |
python -m unittest discover -v -s tests python -m unittest discover -v -s tests
...@@ -6,7 +6,7 @@ from enum import Enum ...@@ -6,7 +6,7 @@ from enum import Enum
from typing import Any, Dict, List from typing import Any, Dict, List
import pkg_resources import pkg_resources
from d2go.utils.oss_helper import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
@fb_overwritable() @fb_overwritable()
......
...@@ -12,7 +12,6 @@ import torch ...@@ -12,7 +12,6 @@ import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.data.dataset_mappers import build_dataset_mapper from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.utils import ClipLengthGroupedDataset from d2go.data.utils import ClipLengthGroupedDataset
from d2go.utils.oss_helper import fb_overwritable
from detectron2.data import ( from detectron2.data import (
build_batch_data_loader, build_batch_data_loader,
build_detection_train_loader, build_detection_train_loader,
...@@ -23,6 +22,7 @@ from detectron2.data.common import DatasetFromList, MapDataset ...@@ -23,6 +22,7 @@ from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import RepeatFactorTrainingSampler from detectron2.data.samplers import RepeatFactorTrainingSampler
from detectron2.utils.comm import get_world_size from detectron2.utils.comm import get_world_size
from mobile_cv.common.misc.oss_utils import fb_overwritable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import os import os
from d2go.utils.oss_helper import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
@fb_overwritable() @fb_overwritable()
......
...@@ -9,9 +9,9 @@ import os ...@@ -9,9 +9,9 @@ import os
from collections import namedtuple from collections import namedtuple
from d2go.utils.helper import get_dir_path from d2go.utils.helper import get_dir_path
from d2go.utils.oss_helper import fb_overwritable
from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.registry import Registry from detectron2.utils.registry import Registry
from mobile_cv.common.misc.oss_utils import fb_overwritable
from .extended_coco import coco_text_load, extended_coco_load from .extended_coco import coco_text_load, extended_coco_load
from .extended_lvis import extended_lvis_load from .extended_lvis import extended_lvis_load
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.utils.oss_helper import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
@fb_overwritable() @fb_overwritable()
......
...@@ -32,7 +32,6 @@ from d2go.utils.flop_calculator import attach_profilers ...@@ -32,7 +32,6 @@ from d2go.utils.flop_calculator import attach_profilers
from d2go.utils.get_default_cfg import get_default_cfg from d2go.utils.get_default_cfg import get_default_cfg
from d2go.utils.helper import D2Trainer, TensorboardXWriter from d2go.utils.helper import D2Trainer, TensorboardXWriter
from d2go.utils.misc import get_tensorboard_log_dir from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.oss_helper import fb_overwritable
from d2go.utils.visualization import DataLoaderVisWrapper, VisualizationEvaluator from d2go.utils.visualization import DataLoaderVisWrapper, VisualizationEvaluator
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.data import ( from detectron2.data import (
...@@ -53,6 +52,7 @@ from detectron2.evaluation import ( ...@@ -53,6 +52,7 @@ from detectron2.evaluation import (
from detectron2.modeling import GeneralizedRCNNWithTTA from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
from detectron2.utils.events import CommonMetricPrinter, JSONWriter from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from mobile_cv.common.misc.oss_utils import fb_overwritable
from mobile_cv.predictor.api import PredictorWrapper from mobile_cv.predictor.api import PredictorWrapper
......
...@@ -12,7 +12,7 @@ from d2go.modeling.meta_arch.fcos import add_fcos_configs ...@@ -12,7 +12,7 @@ from d2go.modeling.meta_arch.fcos import add_fcos_configs
from d2go.modeling.model_freezing_utils import add_model_freezing_configs from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.subclass import add_subclass_configs from d2go.modeling.subclass import add_subclass_configs
from d2go.quantization.modeling import add_quantization_default_configs from d2go.quantization.modeling import add_quantization_default_configs
from d2go.utils.oss_helper import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
@fb_overwritable() @fb_overwritable()
......
...@@ -39,7 +39,6 @@ import detectron2.utils.comm as comm ...@@ -39,7 +39,6 @@ import detectron2.utils.comm as comm
import pkg_resources import pkg_resources
import six import six
import torch import torch
from d2go.utils.oss_helper import fb_overwritable
from detectron2.checkpoint import DetectionCheckpointer from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog from detectron2.data import MetadataCatalog
...@@ -61,6 +60,7 @@ from detectron2.evaluation import ( ...@@ -61,6 +60,7 @@ from detectron2.evaluation import (
SemSegEvaluator, SemSegEvaluator,
verify_results, verify_results,
) )
from mobile_cv.common.misc.oss_utils import fb_overwritable
T = TypeVar("T") T = TypeVar("T")
CallbackMapping = Mapping[Callable, Optional[Iterable[Any]]] CallbackMapping = Mapping[Callable, Optional[Iterable[Any]]]
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.utils.oss_helper import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
@fb_overwritable() @fb_overwritable()
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from mobile_cv.common.misc.py import dynamic_import
def fb_overwritable():
"""Decorator on function that has alternative internal implementation"""
try:
import d2go.utils.fb.open_source_canary # noqa
is_oss = False
except ImportError:
is_oss = True
def deco(oss_func):
if is_oss:
return oss_func
else:
oss_module = oss_func.__module__
fb_module = oss_module + "_fb" # xxx.py -> xxx_fb.py
fb_func = dynamic_import("{}.{}".format(fb_module, oss_func.__name__))
return fb_func
return deco
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
import os import os
from functools import lru_cache from functools import lru_cache
from d2go.utils.oss_helper import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
@fb_overwritable() @fb_overwritable()
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import unittest import unittest
import torch import torch
from detectron2.layers import cat from detectron2.layers import cat
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.structures import Boxes from detectron2.structures import Boxes
from mobile_cv.common.misc.oss_utils import is_oss
class TestBoxWithNMSLimit(unittest.TestCase): class TestBoxWithNMSLimit(unittest.TestCase):
@unittest.skipIf(os.getenv("OSSRUN") == "1", "Caffe2 is not available for OSS") @unittest.skipIf(is_oss(), "Caffe2 is not available for OSS")
def test_caffe2_pytorch_eq(self): def test_caffe2_pytorch_eq(self):
ims_per_batch = 8 ims_per_batch = 8
post_nms_topk = 100 post_nms_topk = 100
......
...@@ -15,6 +15,7 @@ from d2go.utils.testing.data_loader_helper import ( ...@@ -15,6 +15,7 @@ from d2go.utils.testing.data_loader_helper import (
) )
from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts, RCNNBaseTestCases from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts, RCNNBaseTestCases
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.oss_utils import is_oss
# Add APIs to D2's meta arch, this is usually called in runner's setup, however in # Add APIs to D2's meta arch, this is usually called in runner's setup, however in
# unittest it needs to be called sperarately. (maybe we should apply this by default) # unittest it needs to be called sperarately. (maybe we should apply this by default)
...@@ -22,7 +23,7 @@ patch_d2_meta_arch() ...@@ -22,7 +23,7 @@ patch_d2_meta_arch()
def _maybe_skip_test(self, predictor_type): def _maybe_skip_test(self, predictor_type):
if os.getenv("OSSRUN") == "1" and "@c2_ops" in predictor_type: if is_oss() and "@c2_ops" in predictor_type:
self.skipTest("Caffe2 is not available for OSS") self.skipTest("Caffe2 is not available for OSS")
if not torch.cuda.is_available() and "_gpu" in predictor_type: if not torch.cuda.is_available() and "_gpu" in predictor_type:
...@@ -129,7 +130,7 @@ class TestFBNetV3KeypointRCNNFP32(RCNNBaseTestCases.TemplateTestCase): ...@@ -129,7 +130,7 @@ class TestFBNetV3KeypointRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
] ]
) )
def test_export(self, predictor_type, compare_match): def test_export(self, predictor_type, compare_match):
if os.getenv("OSSRUN") == "1" and "@c2_ops" in predictor_type: if is_oss() and "@c2_ops" in predictor_type:
self.skipTest("Caffe2 is not available for OSS") self.skipTest("Caffe2 is not available for OSS")
self._test_export(predictor_type, compare_match=compare_match) self._test_export(predictor_type, compare_match=compare_match)
......
...@@ -10,6 +10,7 @@ from d2go.tools.exporter import main ...@@ -10,6 +10,7 @@ from d2go.tools.exporter import main
from d2go.utils.testing.data_loader_helper import create_local_dataset from d2go.utils.testing.data_loader_helper import create_local_dataset
from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.oss_utils import is_oss
def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self): def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
...@@ -100,6 +101,6 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self): ...@@ -100,6 +101,6 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
class TestOptimizer(unittest.TestCase): class TestOptimizer(unittest.TestCase):
@unittest.skipIf(os.getenv("OSSRUN") == "1", "Caffe2 is not available for OSS") @unittest.skipIf(is_oss(), "Caffe2 is not available for OSS")
def test_maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self): def test_maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self) maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment