Commit 57809b0f authored by Zhicheng Yan's avatar Zhicheng Yan Committed by Facebook GitHub Bot
Browse files

support FP16 gradient compression

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/70

DDP supports an fp16_compress_hook which compresses the gradient to FP16 before communication. This can result in a significant speed up.

Add one argument `_C.MODEL.DDP_FP16_GRAD_COMPRESS` to trigger it.

Reviewed By: zhanghang1989

Differential Revision: D28467701

fbshipit-source-id: 3c80865222f48eb8fe6947ea972448c445ee3ef3
parent daf37a84
@@ -59,6 +59,8 @@ def get_default_cfg(_C):
     # Set find_unused_parameters for DistributedDataParallel.
     _C.MODEL.DDP_FIND_UNUSED_PARAMETERS = False
+    # Set FP16 gradient compression for DistributedDataParallel.
+    _C.MODEL.DDP_FP16_GRAD_COMPRESS = False
     # Set default optimizer
     _C.SOLVER.OPTIMIZER = "sgd"
@@ -16,7 +16,7 @@ from d2go.setup import (
     setup_after_launch,
 )
 from d2go.utils.misc import print_metrics_table, dump_trained_model_configs
-from torch.nn.parallel import DistributedDataParallel
+from detectron2.engine.defaults import create_ddp_model

 logger = logging.getLogger("d2go.tools.train_net")
@@ -53,13 +53,13 @@ def main(
         "metrics": metrics,
     }

-    if comm.get_world_size() > 1:
-        model = DistributedDataParallel(
-            model,
-            device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
-            broadcast_buffers=False,
-            find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
-        )
+    model = create_ddp_model(
+        model,
+        fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
+        device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
+        broadcast_buffers=False,
+        find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
+    )

     trained_cfgs = runner.do_train(cfg, model, resume=resume)
     metrics = runner.do_test(cfg, model)
@@ -88,6 +88,7 @@ def run_with_cmdline_args(args):
         args=(cfg, output_dir, runner, args.eval_only, args.resume),
     )

+
 def cli():
     parser = basic_argument_parser(requires_output_dir=False)
     parser.add_argument(
@@ -100,5 +101,6 @@ def cli():
     )
     run_with_cmdline_args(parser.parse_args())

+
 if __name__ == "__main__":
     cli()

(NOTE(review): both hunks report one added line with no visible textual change in the garbled extraction; the additions are reconstructed here as blank separator lines — confirm against the original commit.)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment