Commit 65eeb427 authored by Deepak Narayanan

Support Torch DDP for single-stage, num_microbatches() > 1

parent 8cd16667
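
This enables PyTorch's DistributedDataParallel (`DDP_impl == 'torch'`) for runs that split the global batch into several microbatches on a single pipeline stage. Torch DDP all-reduces gradients on every `backward()` call, so during gradient accumulation the all-reduce should be deferred to the last microbatch via `model.no_sync()`. A minimal standalone sketch of that pattern, with illustrative (non-Megatron) model/data/loss names:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def train_step(ddp_model: DDP, microbatches, loss_fn, optimizer):
    # All but the last microbatch: accumulate local gradients only;
    # no_sync() suppresses DDP's gradient all-reduce.
    with ddp_model.no_sync():
        for inputs, targets in microbatches[:-1]:
            loss_fn(ddp_model(inputs), targets).backward()
    # The last microbatch runs outside no_sync(), so the accumulated
    # gradients are all-reduced across data-parallel ranks exactly once.
    inputs, targets = microbatches[-1]
    loss_fn(ddp_model(inputs), targets).backward()
    optimizer.step()
    optimizer.zero_grad()
```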
megatron/schedules.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
 import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 
 from megatron import get_args
 from megatron import get_num_microbatches
@@ -74,6 +76,14 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     return input_tensor_grad
 
 
+@contextmanager
+def dummy_handler():
+    try:
+        yield
+    finally:
+        pass
+
+
 def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                    optimizer, timers, forward_only):
     """Run forward and backward passes with no pipeline parallelism
@@ -83,15 +93,27 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
 
     Returns dictionary with losses."""
     assert len(model) == 1
     model = model[0]
 
+    context_handler = dummy_handler
+    if isinstance(model, torchDDP):
+        context_handler = model.no_sync
+
     losses_reduced = []
     input_tensor, output_tensor_grad = None, None
-    for i in range(get_num_microbatches()):
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
-        if not forward_only:
-            backward_step(optimizer, input_tensor, output_tensor,
-                          output_tensor_grad)
+    with context_handler():
+        for i in range(get_num_microbatches() - 1):
+            output_tensor = forward_step(forward_step_func, data_iterator, model,
+                                         input_tensor, losses_reduced)
+            if not forward_only:
+                backward_step(optimizer, input_tensor, output_tensor,
+                              output_tensor_grad)
+
+    # Run computation for last microbatch out of context handler (want to
+    # synchronize gradients).
+    output_tensor = forward_step(forward_step_func, data_iterator, model,
+                                 input_tensor, losses_reduced)
+    if not forward_only:
+        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
 
     return losses_reduced
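
Side note: `dummy_handler` above is a no-op context manager, so the `with context_handler():` block works unchanged on the non-torch-DDP path. On Python 3.7+ `contextlib.nullcontext` behaves the same way (an observation, not part of this commit):

```python
from contextlib import nullcontext

# Enters and exits without side effects, like dummy_handler().
with nullcontext():
    pass
```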
megatron/training.py
@@ -309,8 +309,6 @@ def setup_model_and_optimizer(model_provider_func):
         args.iteration = 0
 
     # We only support local DDP with multiple micro-batches.
-    if get_num_microbatches() > 1:
-        assert args.DDP_impl == 'local'
     if len(model) > 1:
         assert args.DDP_impl == 'local'
     if mpu.get_pipeline_model_parallel_world_size() > 1:
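
With the `get_num_microbatches() > 1` guard removed, torch DDP can be combined with gradient accumulation; local DDP is still required for multi-chunk models (`len(model) > 1`) and for pipeline parallelism. For reference, Megatron derives the microbatch count from the batch-size arguments; a small illustrative calculation (flag names mirror Megatron's CLI, the concrete numbers are made up):

```python
# e.g. --global-batch-size 512 --micro-batch-size 8 on 16 data-parallel ranks
global_batch_size, micro_batch_size, data_parallel_size = 512, 8, 16
num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size)
print(num_microbatches)  # 4: three microbatches run inside no_sync(), the
                         # fourth runs outside it to trigger the all-reduce
```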