Unverified commit 131a5356 authored by Benjamin Lefaudeux, committed by GitHub

[perf][minor] cache the rank lookups, small shardedddp perf fix (#474)

* [perf][minor] cache the rank lookups, small shardedddp perf fix
* tiny improvement, code quality
parent d1fab39e
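The gist of the change: `OSS.get_global_rank(process_group, rank)` was being called on every gradient reduce (and again in the bucket and broadcast paths), even though the mapping from group-local ranks to global ranks is fixed for the lifetime of the process group. Both `ShardedDataParallel` and `OSS` now precompute that mapping once and index a list on the hot path. A minimal standalone sketch of the pattern, assuming `torch.distributed` is already initialized and `fairscale` is importable; `RankCache` is an illustrative wrapper, not part of fairscale:

```python
import torch.distributed as dist

from fairscale.optim import OSS  # provides the static get_global_rank helper used in the diff


class RankCache:
    """Resolve each group-local rank to its global rank once, up front,
    then do O(1) list indexing on the hot path instead of a helper call
    per gradient reduce."""

    def __init__(self, process_group=None):
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        # One lookup per rank, paid once at construction time.
        self._local_to_global_rank = [
            OSS.get_global_rank(self.process_group, i)
            for i in range(dist.get_world_size(self.process_group))
        ]

    def global_rank(self, local_rank: int) -> int:
        return self._local_to_global_rank[local_rank]
```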
@@ -107,7 +107,7 @@ class ShardedDataParallel(nn.Module):
self.enable_broadcast_buffers = broadcast_buffers
self.auto_refresh_trainable = auto_refresh_trainable
self.reduce_fp16 = reduce_fp16
- if reduce_buffer_size > 0:
+ if reduce_buffer_size > 0 and reduce_fp16:
self.reduce_fp16 = False
logging.warning(
"fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
@@ -124,6 +124,9 @@ class ShardedDataParallel(nn.Module):
self.reference_global_rank = OSS.get_global_rank(self.process_group, 0) # picking rank 0 as the reference
self.rank = dist.get_rank(self.process_group)
self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
+ self._local_to_global_rank = [
+ OSS.get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+ ]
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
# See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
@@ -153,10 +156,6 @@ class ShardedDataParallel(nn.Module):
self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
self._reference_trainable_mask = list(map(_trainable, self._all_params))
- # - keep track of the grads which have already been reduced
- self._reduced_grads = 0
- self._reduced_grads_max = 0
# - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()])
self.buffer_max_size = min(reduce_buffer_size, model_size)
@@ -358,7 +357,7 @@ class ShardedDataParallel(nn.Module):
def _clear_counters(self) -> None:
"""Reset all the grad reduce and call counters"""
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
- self._reduced_grads = 0
+ self._bucket_flush_callback_set = False
# Do not reset the buckets
if self.use_buckets:
@@ -375,8 +374,6 @@ class ShardedDataParallel(nn.Module):
bucket.reset()
- self._bucket_flush_callback_set = False
if not self.should_accumulate_grads:
self.accumulate_grads_flipped = False
@@ -406,7 +403,7 @@ class ShardedDataParallel(nn.Module):
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
if not self._bucket_flush_callback_set:
- Variable._execution_engine.queue_callback(self._flush_buckets)
+ Variable._execution_engine.queue_callback(self._flush_reduce_calls)
self._bucket_flush_callback_set = True
# Make sure that this is not fired twice
@@ -425,27 +422,21 @@ class ShardedDataParallel(nn.Module):
param.grad.data = param.grad.data.to(dtype=param.dtype)
# Async reduce for this buffer, log the future
- dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
self._work_handles.append(
Workhandle(
handle=dist.reduce(
- tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True
+ tensor=param.grad.data,
+ dst=self._local_to_global_rank[dst_rank],
+ group=self.process_group,
+ async_op=True,
),
callback=cleanup,
)
)
- self._reduced_grads += 1
- # Opportunistically try to empty the queue
+ # Opportunistically try to empty the queue, free memory
self._try_consume_work_handle()
- # If all the reduce operations have been called,
- # make sure that all the asynchronous calls have concluded before moving on
- # and execute the delayed actions (release gradients, unroll the buckets)
- if self._reduced_grads == self._reduced_grads_max:
- self._consume_work_handles()
else:
@torch.no_grad()
@@ -455,7 +446,7 @@ class ShardedDataParallel(nn.Module):
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
if not self._bucket_flush_callback_set:
- Variable._execution_engine.queue_callback(self._flush_buckets)
+ Variable._execution_engine.queue_callback(self._flush_reduce_calls)
self._bucket_flush_callback_set = True
# Make sure that this is not fired twice
@@ -480,17 +471,10 @@ class ShardedDataParallel(nn.Module):
callback=None,
)
)
- self._reduced_grads += 1
# Opportunistically try to empty the queue
self._try_consume_work_handle()
- # If all the reduce operations have been called,
- # make sure that all the asynchronous calls have concluded before moving on
- # and execute the delayed actions (release gradients, unroll the buckets)
- if self._reduced_grads == self._reduced_grads_max:
- self._consume_work_handles()
return reduce
def _setup_backward_hooks(self) -> None:
@@ -550,9 +534,6 @@ class ShardedDataParallel(nn.Module):
This method can be slow for big models, but it is not typically called often (not on every forward pass, for instance)
"""
- # A priori, one reduce call per param
- self._reduced_grads_max = len(self._trainable_params)
if not self.use_buckets:
return
@@ -572,7 +553,7 @@ class ShardedDataParallel(nn.Module):
]
bucket = self.buckets[device][dst_rank]
- bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)
+ bucket.destination = self._local_to_global_rank[dst_rank]
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
@@ -590,7 +571,6 @@ class ShardedDataParallel(nn.Module):
bucket.fill = fill_next
# Update the bucket
- self._reduced_grads_max -= 1 # one less reduce call per bucketed grad
self.buckets[device][dst_rank].max_params_checked_in += 1
else:
@@ -602,8 +582,6 @@ class ShardedDataParallel(nn.Module):
for bucket in self._bucket_list:
bucket.buffer.resize_(bucket.fill)
bucket.sent = True
- if bucket.max_params_checked_in > 0:
- self._reduced_grads_max += 1 # one reduce call per bucket
def _consume_work_handles(self) -> None:
"""Consume all the futures which are tied to this optimizer's buckets.
@@ -623,20 +601,22 @@ class ShardedDataParallel(nn.Module):
if work_handle.callback is not None:
work_handle.callback()
# Flush all the buckets, just in case
- def _flush_buckets(self) -> None:
+ def _flush_reduce_calls(self) -> None:
if self._bucket_list is not None:
- last_handle = None
for bucket in self._bucket_list:
if not bucket.sent:
# Normalize the bucket in one go
bucket.buffer.mul_(self.world_size_scaling)
# Reduce the bucket
- last_handle = dist.reduce(
+ self._work_handles.append(
+ Workhandle(
+ handle=dist.reduce(
tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+ ),
+ callback=None,
+ )
+ )
bucket.sent = True
- if last_handle is not None:
- last_handle.wait()
+ self._consume_work_handles()
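Dropping `_reduced_grads` / `_reduced_grads_max` is possible because the flush no longer depends on counting reduce calls: the first gradient hook of a backward pass queues a one-shot callback on the autograd engine, and that callback (`_flush_reduce_calls`) sends any unsent buckets and drains the outstanding work handles once backward finishes. Below is a self-contained, single-process sketch of the callback mechanism; `FlushOnBackwardEnd` and its `_flush` body are illustrative only, and `Variable._execution_engine.queue_callback` is the same private PyTorch API already used in the diff:

```python
import torch
from torch.autograd import Variable


class FlushOnBackwardEnd(torch.nn.Module):
    """Queue a one-shot end-of-backward callback from the first grad hook,
    instead of counting how many gradients have been reduced."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self._callback_set = False

    def _flush(self):
        # Stand-in for _flush_reduce_calls(): send pending buckets, then
        # wait on the remaining async handles.
        print("backward done, flushing pending reduce calls")
        self._callback_set = False

    def _grad_hook(self, *_):
        if not self._callback_set:
            # Callbacks queued here run once the current backward pass
            # has fully executed.
            Variable._execution_engine.queue_callback(self._flush)
            self._callback_set = True

    def forward(self, x):
        out = self.linear(x)
        out.register_hook(self._grad_hook)
        return out


if __name__ == "__main__":
    model = FlushOnBackwardEnd()
    model(torch.randn(2, 4)).sum().backward()  # prints the flush message once
```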
@@ -97,8 +97,9 @@ class OSS(Optimizer):
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank)
+ self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self._all_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self._default_device = torch.device("cpu")
@@ -377,6 +378,7 @@ class OSS(Optimizer):
# Make sure that the parameters are sorted in the state, as expected
state_dict["state"] = dict(sorted(state_dict["state"].items()))
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -440,12 +442,10 @@ class OSS(Optimizer):
dist_device=self._default_device,
)
else:
- global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
broadcast_object(
torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
- src_rank=global_rank,
+ src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
)
@@ -470,10 +470,9 @@ class OSS(Optimizer):
)
else:
# Fetch the optim state from the other replicas
- global_rank = self.get_global_rank(self.group, rank)
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
- src_rank=global_rank,
+ src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
)
@@ -548,7 +547,7 @@ class OSS(Optimizer):
if bucket.numel() > 0:
- global_src_rank = self.get_global_rank(self.group, src_rank)
last_work_handle = dist.broadcast(
- tensor=bucket, src=global_src_rank, group=self.group, async_op=True
+ tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
)
# Only check on the last handle, they're all inlined on the same CUDA stream
......
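On the `OSS` side the broadcast path keeps its existing trick, now combined with the cached rank lookup: every per-rank bucket broadcast is issued asynchronously and only the last handle is waited on since, as the comment in the hunk above notes, the collectives of one process group are inlined on the same CUDA stream and therefore finish in issue order under NCCL. A sketch of that pattern, assuming an initialized default process group; `sync_buckets` is an illustrative name, not an OSS method:

```python
from typing import List

import torch
import torch.distributed as dist


def sync_buckets(buckets: List[torch.Tensor], local_to_global_rank: List[int]) -> None:
    """Broadcast one parameter bucket per source rank, async, and only wait
    on the last returned handle."""
    last_work_handle = None
    for src_rank, bucket in enumerate(buckets):
        if bucket.numel() > 0:
            last_work_handle = dist.broadcast(
                tensor=bucket,
                src=local_to_global_rank[src_rank],  # cached lookup, as in the diff
                group=dist.group.WORLD,
                async_op=True,
            )

    # Waiting on the last handle is enough when the backend enqueues all the
    # collectives of this group in order (the NCCL case described above).
    if last_work_handle is not None:
        last_work_handle.wait()
```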
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_features_sharded_ddp.py
tests/nn/data_parallel/test_pytorch_parity_sharded_ddp.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/skip/test_gpipe.py
tests/nn/pipe/skip/test_verify_skippables.py
tests/nn/pipe/skip/test_stash_pop.py
......