Commit 57809b0f authored by Zhicheng Yan's avatar Zhicheng Yan Committed by Facebook GitHub Bot
Browse files

support FP16 gradient compression

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/70

DDP supports an fp16_compress_hook which compresses the gradient to FP16 before communication. This can result in a significant speed-up.

Add one argument `_C.MODEL.DDP_FP16_GRAD_COMPRESS` to trigger it.

Reviewed By: zhanghang1989

Differential Revision: D28467701

fbshipit-source-id: 3c80865222f48eb8fe6947ea972448c445ee3ef3
parent daf37a84
......@@ -59,6 +59,8 @@ def get_default_cfg(_C):
# Set find_unused_parameters for DistributedDataParallel.
_C.MODEL.DDP_FIND_UNUSED_PARAMETERS = False
# Set FP16 gradient compression for DistributedDataParallel.
_C.MODEL.DDP_FP16_GRAD_COMPRESS = False
# Set default optimizer
_C.SOLVER.OPTIMIZER = "sgd"
......
......@@ -16,7 +16,7 @@ from d2go.setup import (
setup_after_launch,
)
from d2go.utils.misc import print_metrics_table, dump_trained_model_configs
from torch.nn.parallel import DistributedDataParallel
from detectron2.engine.defaults import create_ddp_model
logger = logging.getLogger("d2go.tools.train_net")
......@@ -53,9 +53,9 @@ def main(
"metrics": metrics,
}
if comm.get_world_size() > 1:
model = DistributedDataParallel(
model = create_ddp_model(
model,
fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
device_ids=None if cfg.MODEL.DEVICE == "cpu" else [comm.get_local_rank()],
broadcast_buffers=False,
find_unused_parameters=cfg.MODEL.DDP_FIND_UNUSED_PARAMETERS,
......@@ -88,6 +88,7 @@ def run_with_cmdline_args(args):
args=(cfg, output_dir, runner, args.eval_only, args.resume),
)
def cli():
parser = basic_argument_parser(requires_output_dir=False)
parser.add_argument(
......@@ -100,5 +101,6 @@ def cli():
)
run_with_cmdline_args(parser.parse_args())
if __name__ == "__main__":
cli()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment