Commit da53aa10 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

add API reset optimzation engine

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/640

Reviewed By: tglik

Differential Revision: D51908239

fbshipit-source-id: 7bcbad1fc7065b736cf4e38d155eed5d734758f7
parent 409cd213
...@@ -198,6 +198,14 @@ class BaseRunner(object): ...@@ -198,6 +198,14 @@ class BaseRunner(object):
""" """
pass pass
def cleanup(self) -> None:
"""
Override `cleanup` to add custom clean ups such as:
- de-register datasets.
- free up global variables.
"""
pass
@classmethod @classmethod
def create_shared_context(cls, cfg) -> D2GoSharedContext: def create_shared_context(cls, cfg) -> D2GoSharedContext:
""" """
......
...@@ -136,6 +136,7 @@ def main( ...@@ -136,6 +136,7 @@ def main(
metrics = {"_name_": {dataset_name: results}} metrics = {"_name_": {dataset_name: results}}
print_metrics_table(metrics) print_metrics_table(metrics)
runner.cleanup()
return BenchmarkDataOutput( return BenchmarkDataOutput(
accuracy=metrics, accuracy=metrics,
metrics=metrics, metrics=metrics,
......
...@@ -56,6 +56,7 @@ def main( ...@@ -56,6 +56,7 @@ def main(
predictor = create_predictor(predictor_path) predictor = create_predictor(predictor_path)
metrics = runner.do_test(cfg, predictor) metrics = runner.do_test(cfg, predictor)
print_metrics_table(metrics) print_metrics_table(metrics)
runner.cleanup()
return EvaluatorOutput( return EvaluatorOutput(
accuracy=metrics, accuracy=metrics,
metrics=metrics, metrics=metrics,
......
...@@ -95,6 +95,7 @@ def main( ...@@ -95,6 +95,7 @@ def main(
if not skip_if_fail: if not skip_if_fail:
raise e raise e
runner.cleanup()
return ExporterOutput( return ExporterOutput(
predictor_paths=predictor_paths, predictor_paths=predictor_paths,
accuracy_comparison={}, accuracy_comparison={},
......
...@@ -24,7 +24,6 @@ from d2go.setup import ( ...@@ -24,7 +24,6 @@ from d2go.setup import (
setup_root_logger, setup_root_logger,
) )
from d2go.trainer.api import TestNetOutput, TrainNetOutput from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import is_fsdp_enabled
from d2go.utils.mast import gather_mast_errors, mast_error_handler from d2go.utils.mast import gather_mast_errors, mast_error_handler
from d2go.utils.misc import ( from d2go.utils.misc import (
dump_trained_model_configs, dump_trained_model_configs,
...@@ -68,6 +67,7 @@ def main( ...@@ -68,6 +67,7 @@ def main(
model.eval() model.eval()
metrics = runner.do_test(cfg, model, train_iter=train_iter) metrics = runner.do_test(cfg, model, train_iter=train_iter)
print_metrics_table(metrics) print_metrics_table(metrics)
runner.cleanup()
return TestNetOutput( return TestNetOutput(
accuracy=metrics, accuracy=metrics,
metrics=metrics, metrics=metrics,
...@@ -98,6 +98,7 @@ def main( ...@@ -98,6 +98,7 @@ def main(
# dump config files for trained models # dump config files for trained models
trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs) trained_model_configs = dump_trained_model_configs(cfg.OUTPUT_DIR, trained_cfgs)
runner.cleanup()
return TrainNetOutput( return TrainNetOutput(
# for e2e_workflow # for e2e_workflow
accuracy=metrics, accuracy=metrics,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment