Unverified Commit 4402c410 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[cleanup] ShardedDDP - inline gatekeeper (#248)

parent f74afebb
...@@ -15,7 +15,6 @@ from typing import Any, Callable, Generator, List, Tuple, Union
import torch
from torch import nn
from torch.autograd import Variable
import torch.distributed as dist
from torch.nn import Parameter
...@@ -177,11 +176,6 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
def gatekeeper() -> None:
# Make sure that all the asynchronous calls have concluded before moving on. Consume the futures
# and execute the delayed actions (release gradients, unroll the buckets)
Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
def reduce_direct(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
...@@ -206,9 +200,11 @@ class ShardedDataParallel(nn.Module):
)
)
# If all the reduce operations have been called, add the gatekeeper # If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if len(optimizer.work_handles) == optimizer._max_work_handles:
gatekeeper() optimizer._consume_work_handles()
# Bucket, update status, and possibly unroll the results
def reduce_bucket(*_: Any) -> None:
...@@ -240,9 +236,11 @@ class ShardedDataParallel(nn.Module):
)
)
# If all the reduce operations have been called, add the gatekeeper # If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if len(optimizer.work_handles) == optimizer._max_work_handles:
gatekeeper() optimizer._consume_work_handles()
return reduce_bucket if should_bucket else reduce_direct
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment