Commit f0f55cdc authored by Sudarshan Raghunathan's avatar Sudarshan Raghunathan Committed by Facebook GitHub Bot
Browse files

Add reply files to d2go training processes

Summary:
This diff contains a minimal set of changes to support returning reply files to MAST.

There are three parts:
1. First, we have a try..except in the main function to catch all the "catchable" Python exceptions. Exceptions from C++ code or segfaults will not be handled here.
2. Each exception is then written to a per-process JSON reply file.
3. At the end, all per-process files are stat-ed and the earliest file is copied to a location specified by MAST.

# Limitations
1. This only works when local processes are launched using multiprocessing (which is the default)
2. If an error happens in C++ code, it will likely not be caught in Python, and the reply file might not contain the correct logs

Differential Revision: D43097683

fbshipit-source-id: 0eaf4f19f6199a9c77f2ce4c7d2bbc2a2078be99
parent b21607b1
......@@ -31,6 +31,11 @@ from d2go.utils.misc import (
)
from detectron2.engine.defaults import create_ddp_model
from torch.distributed.elastic.multiprocessing.errors import (
_NOT_AVAILABLE,
ChildFailedError,
get_error_handler,
)
logger = logging.getLogger("d2go.tools.train_net")
......@@ -42,6 +47,15 @@ def main(
eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster
) -> Union[TrainNetOutput, TestNetOutput]:
logger.info("Starting main")
error_handler = get_error_handler()
logger.debug(f">>>>>>> Error handler is: {type(error_handler)=}, {error_handler=}")
error_handler.initialize()
logger.debug("Error handler has been initialized")
try: # Main error handler starts here...
logger.debug(f"Entered main for d2go, {runner_class=}")
runner = setup_after_launch(cfg, output_dir, runner_class)
model = runner.build_model(cfg)
......@@ -52,7 +66,9 @@ def main(
# checkpointer.resume_or_load() will skip all additional checkpointable
# which may not be desired like ema states
if resume and checkpointer.has_checkpoint():
checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
checkpoint = checkpointer.resume_or_load(
cfg.MODEL.WEIGHTS, resume=resume
)
else:
checkpoint = checkpointer.load(cfg.MODEL.WEIGHTS)
train_iter = checkpoint.get("iteration", None)
......@@ -70,11 +86,14 @@ def main(
model = create_ddp_model(
model,
fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
device_ids=None
if cfg.MODEL.DEVICE == "cpu"
else [comm.get_local_rank()],
broadcast_buffers=False,
find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
)
logger.info("Starting train..")
trained_cfgs = runner.do_train(cfg, model, resume=resume)
final_eval = cfg.TEST.FINAL_EVAL
......@@ -94,6 +113,24 @@ def main(
model_configs=trained_model_configs,
metrics=metrics,
)
except ChildFailedError as e:
logger.info(f"Got a ChildFailedError: {e=}")
rank, failure = e.get_first_failure()
if failure.error_file != _NOT_AVAILABLE:
error_handler.dump_error_file(failure.error_file, failure.exitcode)
else:
logger.info(
(
f"local_rank {rank} FAILED with no error file."
f" Decorate your entrypoint fn with @record for traceback info."
f" See: https://pytorch.org/docs/stable/elastic/errors.html"
)
)
raise
except Exception as e:
logger.info(f"Caught a generic exception: {e=}")
error_handler.record_exception(e)
raise
def run_with_cmdline_args(args):
......@@ -122,6 +159,7 @@ def run_with_cmdline_args(args):
def cli(args=None):
logger.info(f"Inside CLI, {args=}")
parser = basic_argument_parser(requires_output_dir=False)
parser.add_argument(
"--eval-only", action="store_true", help="perform evaluation only"
......@@ -153,4 +191,29 @@ def build_cli_args(
if __name__ == "__main__":
    setup_root_logger()
    logger.info("Starting CLI application")
    try:
        cli()
    finally:
        # Always attempt to surface a reply file to MAST, even when cli()
        # raised: each failed local process has written its own JSON reply
        # file, and MAST expects a single file at the path it provides.
        # BUGFIX: was `logging.info(...)` (root logger) — use the module's
        # named logger like every other log call in this file.
        logger.info("Entering final reply file generation step")
        import glob
        import os
        import shutil

        # Per-process reply files written by torchelastic's error handler.
        # NOTE(review): this only covers processes launched via
        # multiprocessing on this host — see the commit description.
        torchx_reply_files = glob.glob("/tmp/torchx_*/**/*.json", recursive=True)
        logger.info(
            f"Found the following reply files on this host: {torchx_reply_files}"
        )

        # The earliest-modified file corresponds to the first failure,
        # which is the most relevant root cause to report upstream.
        first_reply_file = min(
            torchx_reply_files,
            key=lambda path: os.stat(path).st_mtime,
            default=None,
        )

        # Copy only when we found a reply file AND MAST told us where to
        # put it (env var is unset outside of MAST jobs).
        mast_reply_path = os.environ.get("MAST_HPC_TASK_FAILURE_REPLY_FILE")
        if first_reply_file and mast_reply_path:
            logger.info(f"Copying {first_reply_file=} to {mast_reply_path}")
            shutil.copyfile(first_reply_file, mast_reply_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment