Commit 8521bb22 authored by Michael Carilli

Patching in changes to enable multiple allreduces in flight

parent 61b8a0fd
# Introduction
This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in PyTorch.
Some of the code here will be included in upstream PyTorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
@@ -29,7 +29,7 @@ different flags to `amp.initialize`.
## 2. Distributed Training
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
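
A minimal usage sketch of the wrapper (the launch method, process-group setup, and toy model below are illustrative assumptions, not Apex APIs beyond `DistributedDataParallel` itself):

```python
import os
import torch
from apex.parallel import DistributedDataParallel as DDP

# Assumes one process per GPU, launched so that RANK / WORLD_SIZE / LOCAL_RANK
# (and MASTER_ADDR / MASTER_PORT) are set in the environment,
# e.g. via torch.distributed.launch.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(1024, 1024).cuda()
model = DDP(model)  # wraps the module; gradient allreduces are issued during backward()

out = model(torch.randn(32, 1024, device="cuda"))
out.sum().backward()  # gradients are averaged across all processes here
```

Because the allreduces are issued while the backward pass is still running, they overlap with gradient computation; per the commit message, this patch extends that overlap by enabling multiple allreduces to be in flight at once.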
@@ -133,7 +133,7 @@ def _initialize(models, optimizers, properties, num_losses=1):
    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    check_optimizers(optimizers)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
@@ -163,7 +163,7 @@ def _initialize(models, optimizers, properties, num_losses=1):
        model.forward = patch_forward(model.forward)

    # State dict trick to recast any preexisting per-param state tensors
    for optimizer in optimizers:
        optimizer.load_state_dict(optimizer.state_dict())
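
The state dict round-trip works because `torch.optim.Optimizer.load_state_dict` casts floating-point state tensors to the dtype and device of the parameters they belong to, so reloading an optimizer's own state dict re-materializes buffers such as Adam's `exp_avg` in the model's new precision. A standalone sketch of the effect, outside of amp (the model, optimizer, and sizes are illustrative):

```python
import torch

# Illustrative setup; not part of apex itself.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Populate per-param state (exp_avg / exp_avg_sq) while everything is FP32.
model(torch.randn(4, 16, device="cuda")).sum().backward()
optimizer.step()

model.half()  # parameters become FP16 in place; optimizer state is still FP32
print(optimizer.state[model.weight]["exp_avg"].dtype)  # torch.float32

# Round-tripping the optimizer's own state dict recasts the per-param state
# tensors to match the (now FP16) parameters they belong to.
optimizer.load_state_dict(optimizer.state_dict())
print(optimizer.state[model.weight]["exp_avg"].dtype)  # torch.float16
```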