Commit 5e6f52bf authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

clean up d2go/utils/helper

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/110

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/380

- remove `alias`
- only annotate different implementation with `fb_overwrite`
- fix lint

Reviewed By: itomatik

Differential Revision: D39981383

fbshipit-source-id: 9739b7026510b3f1a2e69fe1de5b3f721759a209
parent 382bec5b
...@@ -9,10 +9,10 @@ import numpy as np ...@@ -9,10 +9,10 @@ import numpy as np
import torch import torch
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers import D2GO_DATA_MAPPER_REGISTRY, D2GoDatasetMapper from d2go.data.dataset_mappers import D2GO_DATA_MAPPER_REGISTRY, D2GoDatasetMapper
from d2go.utils.helper import alias
from detectron2.layers import cat from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry from detectron2.utils.registry import Registry
from mobile_cv.torch.utils_toffee.alias import alias
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#!/usr/bin/python #!/usr/bin/python
import errno
import importlib import importlib
import inspect
import logging
import math
import os import os
import pickle from functools import wraps
import re from typing import Any, Callable, List, TypeVar
import signal
import sys
import tempfile
import threading
import time
import traceback
import typing
import warnings
import zipfile
from contextlib import contextmanager
from functools import partial, wraps
from random import random
from typing import (
Any,
Callable,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import pkg_resources
import six
import torch import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog from detectron2.data import MetadataCatalog
from detectron2.engine import ( from detectron2.engine import DefaultTrainer
default_argument_parser,
default_setup,
DefaultTrainer,
hooks,
launch,
)
from detectron2.evaluation import ( from detectron2.evaluation import (
CityscapesInstanceEvaluator, CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator, CityscapesSemSegEvaluator,
...@@ -58,18 +20,22 @@ from detectron2.evaluation import ( ...@@ -58,18 +20,22 @@ from detectron2.evaluation import (
LVISEvaluator, LVISEvaluator,
PascalVOCDetectionEvaluator, PascalVOCDetectionEvaluator,
SemSegEvaluator, SemSegEvaluator,
verify_results,
) )
from detectron2.utils.events import TensorboardXWriter
from mobile_cv.common.misc.oss_utils import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
T = TypeVar("T") T = TypeVar("T")
CallbackMapping = Mapping[Callable, Optional[Iterable[Any]]]
FuncType = Callable[..., Any] FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType) F = TypeVar("F", bound=FuncType)
RT = TypeVar("RT")
NT = TypeVar("T", bound=NamedTuple)
from detectron2.utils.events import TensorboardXWriter
__all__ = [
"run_once",
"retryable",
"get_dir_path",
"TensorboardXWriter", # TODO: move to D2Go's vis utils if needed
"D2Trainer", # TODO: move to trainer folder
]
class MultipleFunctionCallError(Exception): class MultipleFunctionCallError(Exception):
...@@ -135,15 +101,6 @@ def get_dir_path(relative_path): ...@@ -135,15 +101,6 @@ def get_dir_path(relative_path):
return os.path.dirname(importlib.import_module(relative_path).__file__) return os.path.dirname(importlib.import_module(relative_path).__file__)
# copy util function for oss
def alias(x, name, is_backward=False):
if not torch.onnx.is_in_onnx_export():
return x
assert isinstance(x, torch.Tensor)
return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
@fb_overwritable()
class D2Trainer(DefaultTrainer): class D2Trainer(DefaultTrainer):
@classmethod @classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None): def build_evaluator(cls, cfg, dataset_name, output_folder=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment