Unverified commit 4402c410, authored by Benjamin Lefaudeux and committed by GitHub
Browse files

[cleanup] ShardedDDP - inline gatekeeper (#248)

parent f74afebb
......@@ -15,7 +15,6 @@ from typing import Any, Callable, Generator, List, Tuple, Union
import torch
from torch import nn
from torch.autograd import Variable
import torch.distributed as dist
from torch.nn import Parameter
......@@ -177,11 +176,6 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
def gatekeeper() -> None:
# Make sure that all the asynchronous calls have concluded before moving on. Consume the futures
# and execute the delayed actions (release gradients, unroll the buckets)
Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
def reduce_direct(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
......@@ -206,9 +200,11 @@ class ShardedDataParallel(nn.Module):
)
)
# If all the reduce operations have been called, add the gatekeeper
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if len(optimizer.work_handles) == optimizer._max_work_handles:
gatekeeper()
optimizer._consume_work_handles()
# Bucket, update status, and possibly unroll the results
def reduce_bucket(*_: Any) -> None:
......@@ -240,9 +236,11 @@ class ShardedDataParallel(nn.Module):
)
)
# If all the reduce operations have been called, add the gatekeeper
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if len(optimizer.work_handles) == optimizer._max_work_handles:
gatekeeper()
optimizer._consume_work_handles()
return reduce_bucket if should_bucket else reduce_direct
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment