Commit e4fa6d63 authored by Francisc Bungiu, committed by Facebook GitHub Bot

Extend reply files to all binaries

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/591

We previously added reply files for train_net, but not for the other relevant binaries with MAST support: evaluator and lightning.
Add support here by extracting the common bits into a separate module (d2go/utils/mast.py) and wrapping the entry-point functions to reuse the functionality; a short usage sketch follows below.

Differential Revision: D47293689

fbshipit-source-id: 70630a471c0cf037d180c9edfb57a4db4fdf7bde
parent 53748d9d
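
For orientation, the reuse pattern in the diffs below is the same in every binary: the per-process main is routed through mast_error_handler, and the if __name__ == "__main__" entry point through gather_mast_errors. A minimal sketch of the main-wrapping half, assuming d2go is installed; only the d2go.utils.mast import comes from this commit, the other names are illustrative stand-ins:

# Sketch: how a binary reuses the shared MAST error handler from this commit.
# Only the `d2go.utils.mast` import is real here; `main` is a stand-in for the
# tool's existing per-worker entry point.
from d2go.utils.mast import mast_error_handler


def main(cfg=None, output_dir=None):
    # Placeholder for the binary's existing main(); unchanged by this commit.
    print(f"running main with {cfg=}, {output_dir=}")


def wrapped_main(*args, **kwargs):
    # ChildFailedError reply-file dumping and generic exception recording now
    # happen inside mast_error_handler's wrapper instead of in each binary.
    return mast_error_handler(main)(*args, **kwargs)


if __name__ == "__main__":
    wrapped_main(cfg=None, output_dir="/tmp/out")

Each binary then points its launcher (or the post-mortem variant) at wrapped_main instead of main, and wraps its CLI entry point with gather_mast_errors so that the earliest torchx reply file on the host is copied to MAST_HPC_TASK_FAILURE_REPLY_FILE even when the run fails.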
d2go/utils/mast.py (new file):

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Callable

from torch.distributed.elastic.multiprocessing.errors import (
    _NOT_AVAILABLE,
    ChildFailedError,
    get_error_handler,
)

logger = logging.getLogger(__name__)


# Wraps a per-process main() so that elastic error reporting (reply files for
# ChildFailedError, recording of generic exceptions) is shared across binaries.
def mast_error_handler(func: Callable) -> Callable:
    def wrapper(*args, **kwargs):
        logger.info("Starting main")
        error_handler = get_error_handler()
        logger.debug(f"Error handler is: {type(error_handler)=}, {error_handler=}")
        error_handler.initialize()
        logger.debug("Error handler has been initialized")
        try:
            logger.debug("Entered main for d2go")
            return func(*args, **kwargs)
        except ChildFailedError as e:
            logger.info(f"Got a ChildFailedError: {e=}")
            rank, failure = e.get_first_failure()
            if failure.error_file != _NOT_AVAILABLE:
                error_handler.dump_error_file(failure.error_file, failure.exitcode)
            else:
                logger.info(
                    (
                        f"local_rank {rank} FAILED with no error file."
                        f" Decorate your entrypoint fn with @record for traceback info."
                        f" See: https://pytorch.org/docs/stable/elastic/errors.html"
                    )
                )
            raise
        except Exception as e:
            logger.info(f"Caught a generic exception: {e=}")
            error_handler.record_exception(e)
            raise

    return wrapper


# Wraps a CLI entry point so that, on exit, the earliest torchx reply file on
# this host is copied to the MAST failure reply file location (if configured).
def gather_mast_errors(func: Callable) -> Callable:
    def wrapper(*args, **kwargs) -> None:
        logger.info("Starting CLI application")
        try:
            func(*args, **kwargs)
        finally:
            logging.info("Entering final reply file generation step")
            import glob
            import os
            import shutil

            torchx_reply_files = glob.glob("/tmp/torchx_*/**/*.json", recursive=True)
            logger.info(
                f"Found the following reply files on this host: {torchx_reply_files}"
            )
            # Pick the oldest reply file on this host, i.e. the first failure.
            first_reply_file = None
            first_reply_file_st = float("Inf")
            for f in torchx_reply_files:
                if (mtime := os.stat(f).st_mtime) < first_reply_file_st:
                    first_reply_file = f
                    first_reply_file_st = mtime
            if first_reply_file and os.environ.get("MAST_HPC_TASK_FAILURE_REPLY_FILE"):
                logger.info(
                    f'Copying {first_reply_file=} to {os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]}'
                )
                shutil.copyfile(
                    first_reply_file, os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]
                )

    return wrapper
d2go/tools/evaluator.py:

@@ -8,13 +8,11 @@ torchscript, caffe2, etc.) using Detectron2Go system (dataloading, evaluation, e
 import logging
 import sys
-from dataclasses import dataclass
-from typing import List, Optional, Type, Union
+from typing import Callable, List, Optional, Type, Union

 import torch
 from d2go.config import CfgNode
 from d2go.distributed import launch
-from d2go.evaluation.api import AccuracyDict, MetricsDict
 from d2go.quantization.qconfig import smart_decode_backend
 from d2go.runner import BaseRunner
 from d2go.setup import (
@@ -28,6 +26,7 @@ from d2go.setup import (
     setup_root_logger,
 )
 from d2go.trainer.api import EvaluatorOutput
+from d2go.utils.mast import gather_mast_errors, mast_error_handler
 from d2go.utils.misc import print_metrics_table, save_binary_outputs
 from mobile_cv.predictor.api import create_predictor
@@ -63,10 +62,18 @@ def main(
     )


+def wrapped_main(*args, **kwargs) -> Callable:
+    return mast_error_handler(main)(*args, **kwargs)
+
+
 def run_with_cmdline_args(args):
     cfg, output_dir, runner_name = prepare_for_launch(args)
     shared_context = setup_before_launch(cfg, output_dir, runner_name)
-    main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
+    main_func = (
+        wrapped_main
+        if args.disable_post_mortem
+        else post_mortem_if_fail_for_main(wrapped_main)
+    )
     outputs = launch(
         main_func,
         args.num_processes,
@@ -137,4 +144,4 @@ def cli(args=None):
 if __name__ == "__main__":
     setup_root_logger()
-    cli()
+    gather_mast_errors(cli())
d2go/tools/train_net.py:

@@ -7,7 +7,7 @@ Detection Training Script.
 import logging
 import sys
-from typing import List, Type, Union
+from typing import Callable, List, Type, Union

 import detectron2.utils.comm as comm
 from d2go.config import CfgNode
@@ -24,6 +24,7 @@ from d2go.setup import (
 )
 from d2go.trainer.api import TestNetOutput, TrainNetOutput
 from d2go.trainer.fsdp import is_fsdp_enabled
+from d2go.utils.mast import gather_mast_errors, mast_error_handler
 from d2go.utils.misc import (
     dump_trained_model_configs,
     print_metrics_table,
@@ -31,12 +32,6 @@ from d2go.utils.misc import (
 )
 from detectron2.engine.defaults import create_ddp_model
-from torch.distributed.elastic.multiprocessing.errors import (
-    _NOT_AVAILABLE,
-    ChildFailedError,
-    get_error_handler,
-)

 logger = logging.getLogger("d2go.tools.train_net")

 # Make sure logging is set up centrally even for e.g. dataloading workers which
 # have entry points outside of D2Go.
@@ -50,13 +45,6 @@ def main(
     eval_only: bool = False,
     resume: bool = True,  # NOTE: always enable resume when running on cluster
 ) -> Union[TrainNetOutput, TestNetOutput]:
-    logger.info("Starting main")
-    error_handler = get_error_handler()
-    logger.debug(f"Error handler is: {type(error_handler)=}, {error_handler=}")
-    error_handler.initialize()
-    logger.debug("Error handler has been initialized")
-    try:  # Main error handler starts here...
-        logger.debug(f"Entered main for d2go, {runner_class=}")
-        runner = setup_after_launch(cfg, output_dir, runner_class)
+    logger.debug(f"Entered main for d2go, {runner_class=}")
+    runner = setup_after_launch(cfg, output_dir, runner_class)
@@ -68,9 +56,7 @@ def main(
     # checkpointer.resume_or_load() will skip all additional checkpointable
     # which may not be desired like ema states
     if resume and checkpointer.has_checkpoint():
-        checkpoint = checkpointer.resume_or_load(
-            cfg.MODEL.WEIGHTS, resume=resume
-        )
+        checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
     else:
         checkpoint = checkpointer.load(cfg.MODEL.WEIGHTS)
     train_iter = checkpoint.get("iteration", None)
@@ -88,9 +74,7 @@ def main(
     model = create_ddp_model(
         model,
         fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
-        device_ids=None
-        if cfg.MODEL.DEVICE == "cpu"
-        else [comm.get_local_rank()],
+        device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
         broadcast_buffers=False,
         find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
         gradient_as_bucket_view=cfg.MODEL.DDP_GRADIENT_AS_BUCKET_VIEW,
@@ -116,31 +100,21 @@ def main(
         model_configs=trained_model_configs,
         metrics=metrics,
     )
-    except ChildFailedError as e:
-        logger.info(f"Got a ChildFailedError: {e=}")
-        rank, failure = e.get_first_failure()
-        if failure.error_file != _NOT_AVAILABLE:
-            error_handler.dump_error_file(failure.error_file, failure.exitcode)
-        else:
-            logger.info(
-                (
-                    f"local_rank {rank} FAILED with no error file."
-                    f" Decorate your entrypoint fn with @record for traceback info."
-                    f" See: https://pytorch.org/docs/stable/elastic/errors.html"
-                )
-            )
-        raise
-    except Exception as e:
-        logger.info(f"Caught a generic exception: {e=}")
-        error_handler.record_exception(e)
-        raise
+
+
+def wrapped_main(*args, **kwargs) -> Callable:
+    return mast_error_handler(main)(*args, **kwargs)


 def run_with_cmdline_args(args):
     cfg, output_dir, runner_name = prepare_for_launch(args)
     shared_context = setup_before_launch(cfg, output_dir, runner_name)
-    main_func = main if args.disable_post_mortem else post_mortem_if_fail_for_main(main)
+    main_func = (
+        wrapped_main
+        if args.disable_post_mortem
+        else post_mortem_if_fail_for_main(wrapped_main)
+    )

     if args.run_as_worker:
         logger.info("Running as worker")
@@ -176,6 +150,8 @@ def run_with_cmdline_args(args):
     # Only save results from global rank 0 for consistency.
     if args.save_return_file is not None and args.machine_rank == 0:
+        logger.info(f"Operator results: {outputs[0]}")
+        logger.info(f"Writing results to {args.save_return_file}.")
         save_binary_outputs(args.save_return_file, outputs[0])
@@ -211,29 +187,4 @@ def build_cli_args(
 if __name__ == "__main__":
-    logger.info("Starting CLI application")
-    try:
-        cli()
-    finally:
-        logging.info("Entering final reply file generation step")
-        import glob
-        import os
-        import shutil
-
-        torchx_reply_files = glob.glob("/tmp/torchx_*/**/*.json", recursive=True)
-        logger.info(
-            f"Found the following reply files on this host: {torchx_reply_files}"
-        )
-        first_reply_file = None
-        first_reply_file_st = float("Inf")
-        for f in torchx_reply_files:
-            if (mtime := os.stat(f).st_mtime) < first_reply_file_st:
-                first_reply_file = f
-                first_reply_file_st = mtime
-        if first_reply_file and os.environ.get("MAST_HPC_TASK_FAILURE_REPLY_FILE"):
-            logger.info(
-                f'Copying {first_reply_file=} to {os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]}'
-            )
-            shutil.copyfile(
-                first_reply_file, os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]
-            )
+    gather_mast_errors(cli())