Commit 65eeb427 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Support Torch DDP for single-stage, num_microbatches() > 1

parent 8cd16667
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from contextlib import contextmanager
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args from megatron import get_args
from megatron import get_num_microbatches from megatron import get_num_microbatches
...@@ -74,6 +76,14 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): ...@@ -74,6 +76,14 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
return input_tensor_grad return input_tensor_grad
@contextmanager
def dummy_handler():
    """No-op context manager.

    Used as a drop-in stand-in for ``model.no_sync()`` when the model is
    not wrapped in torch ``DistributedDataParallel``, so the caller can
    select a gradient-sync-suppressing context handler unconditionally.

    NOTE(review): equivalent to ``contextlib.nullcontext()`` (Python 3.7+);
    kept as a named function to preserve the existing call sites. The
    original ``try: yield finally: pass`` was dead scaffolding — a bare
    ``yield`` behaves identically (exceptions still propagate).
    """
    yield
def forward_backward_no_pipelining(forward_step_func, data_iterator, model, def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only): optimizer, timers, forward_only):
"""Run forward and backward passes with no pipeline parallelism """Run forward and backward passes with no pipeline parallelism
...@@ -83,14 +93,26 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model, ...@@ -83,14 +93,26 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
assert len(model) == 1 assert len(model) == 1
model = model[0] model = model[0]
context_handler = dummy_handler
if isinstance(model, torchDDP):
context_handler = model.no_sync
losses_reduced = [] losses_reduced = []
for i in range(get_num_microbatches()): input_tensor, output_tensor_grad = None, None
input_tensor, output_tensor_grad = None, None with context_handler():
output_tensor = forward_step(forward_step_func, data_iterator, model, for i in range(get_num_microbatches() - 1):
input_tensor, losses_reduced) output_tensor = forward_step(forward_step_func, data_iterator, model,
if not forward_only: input_tensor, losses_reduced)
backward_step(optimizer, input_tensor, output_tensor, if not forward_only:
output_tensor_grad) backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return losses_reduced return losses_reduced
......
...@@ -309,8 +309,6 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -309,8 +309,6 @@ def setup_model_and_optimizer(model_provider_func):
args.iteration = 0 args.iteration = 0
# We only support local DDP with multiple micro-batches. # We only support local DDP with multiple micro-batches.
if get_num_microbatches() > 1:
assert args.DDP_impl == 'local'
if len(model) > 1: if len(model) > 1:
assert args.DDP_impl == 'local' assert args.DDP_impl == 'local'
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment