"...python/git@developer.sourcefind.cn:change/sglang.git" did not exist on "c9bf3877a0a02a80267bd851fb712c30f3bf9ccd"
Commit b04ba38b authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

add evaluation result type annotation

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/343

Reviewed By: miqueljubert

Differential Revision: D38077850

fbshipit-source-id: a79541d899ce2b49a30c7f2a81a616f76321026f
parent 5c16a4ea
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, TypeVar, Union
T = TypeVar("T")
# "accuracy" in D2Go is defined by a 4-level dictionary in the order of:
# model_tag -> dataset -> task -> metrics
AccuracyDict = Dict[str, Dict[str, Dict[str, Dict[str, T]]]]
# "metric" in D2Go is a nested dictionary, which may have arbitrary levels.
MetricsDict = Union[Dict[str, "MetricsDict"], T]
...@@ -7,12 +7,13 @@ Tool for benchmarking data loading ...@@ -7,12 +7,13 @@ Tool for benchmarking data loading
import logging import logging
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Type, Union from typing import Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import numpy as np import numpy as np
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.distributed import get_num_processes_per_machine, launch from d2go.distributed import get_num_processes_per_machine, launch
from d2go.evaluation.api import AccuracyDict, MetricsDict
from d2go.runner import BaseRunner from d2go.runner import BaseRunner
from d2go.setup import ( from d2go.setup import (
basic_argument_parser, basic_argument_parser,
...@@ -30,9 +31,8 @@ logger = logging.getLogger("d2go.tools.benchmark_data") ...@@ -30,9 +31,8 @@ logger = logging.getLogger("d2go.tools.benchmark_data")
@dataclass @dataclass
class BenchmarkDataOutput: class BenchmarkDataOutput:
accuracy: Dict[str, Any] accuracy: AccuracyDict[float]
# TODO: support arbitrary levels of dicts metrics: MetricsDict[float]
metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]
def main( def main(
......
...@@ -9,11 +9,12 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e ...@@ -9,11 +9,12 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
import logging import logging
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Type, Union from typing import Optional, Type, Union
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.distributed import launch from d2go.distributed import launch
from d2go.evaluation.api import AccuracyDict, MetricsDict
from d2go.runner import BaseRunner from d2go.runner import BaseRunner
from d2go.setup import ( from d2go.setup import (
basic_argument_parser, basic_argument_parser,
...@@ -31,9 +32,8 @@ logger = logging.getLogger("d2go.tools.caffe2_evaluator") ...@@ -31,9 +32,8 @@ logger = logging.getLogger("d2go.tools.caffe2_evaluator")
@dataclass @dataclass
class EvaluatorOutput: class EvaluatorOutput:
accuracy: Dict[str, Any] accuracy: AccuracyDict[float]
# TODO: support arbitrary levels of dicts metrics: MetricsDict[float]
metrics: Dict[str, Dict[str, Dict[str, Dict[str, float]]]]
def main( def main(
......
...@@ -8,11 +8,12 @@ Detection Training Script. ...@@ -8,11 +8,12 @@ Detection Training Script.
import logging import logging
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, Union from typing import Dict, List, Optional, Type, Union
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.distributed import launch from d2go.distributed import launch
from d2go.evaluation.api import AccuracyDict, MetricsDict
from d2go.runner import BaseRunner from d2go.runner import BaseRunner
from d2go.setup import ( from d2go.setup import (
basic_argument_parser, basic_argument_parser,
...@@ -34,9 +35,8 @@ logger = logging.getLogger("d2go.tools.train_net") ...@@ -34,9 +35,8 @@ logger = logging.getLogger("d2go.tools.train_net")
@dataclass @dataclass
class TrainNetOutput: class TrainNetOutput:
accuracy: Dict[str, Any] accuracy: AccuracyDict[float]
# TODO: support arbitrary levels of dicts metrics: MetricsDict[float]
metrics: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
model_configs: Dict[str, str] model_configs: Dict[str, str]
# TODO: decide if `tensorboard_log_dir` should be part of output # TODO: decide if `tensorboard_log_dir` should be part of output
tensorboard_log_dir: Optional[str] = None tensorboard_log_dir: Optional[str] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment