Unverified commit 131a5356, authored by Benjamin Lefaudeux, committed by GitHub

[perf][minor] cache the rank lookups, small shardedddp perf fix (#474)

* [perf][minor] cache the rank lookups, small shardedddp perf fix
* tiny improvement, code quality
parent d1fab39e
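For context, the gist of the change: OSS.get_global_rank() used to be called on every gradient reduce and every broadcast to translate a local rank into a global one; this commit computes that mapping once at construction time and indexes into it afterwards. A rough, hedged sketch of the pattern (the wrapper class and method names below are made up for illustration; OSS.get_global_rank and the torch.distributed calls are the ones used in the diff):

import torch.distributed as dist
from fairscale.optim.oss import OSS


class CachedRankLookupSketch:
    def __init__(self, process_group=None):
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        # Resolve every local rank to its global rank once, instead of once per reduce call
        self._local_to_global_rank = [
            OSS.get_global_rank(self.process_group, i)
            for i in range(dist.get_world_size(self.process_group))
        ]

    def reduce_grad(self, grad, dst_rank):
        # Hot path during backward: the rank translation is now a plain list index
        return dist.reduce(
            tensor=grad,
            dst=self._local_to_global_rank[dst_rank],
            group=self.process_group,
            async_op=True,
        )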
@@ -107,7 +107,7 @@ class ShardedDataParallel(nn.Module):
         self.enable_broadcast_buffers = broadcast_buffers
         self.auto_refresh_trainable = auto_refresh_trainable
         self.reduce_fp16 = reduce_fp16
-        if reduce_buffer_size > 0:
+        if reduce_buffer_size > 0 and reduce_fp16:
             self.reduce_fp16 = False
             logging.warning(
                 "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
@@ -124,6 +124,9 @@ class ShardedDataParallel(nn.Module):
         self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
         self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
+        self._local_to_global_rank = [
+            OSS.get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+        ]

         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
         # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
@@ -153,10 +156,6 @@ class ShardedDataParallel(nn.Module):
         self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
         self._reference_trainable_mask = list(map(_trainable, self._all_params))

-        # - keep track of the grads which have already been reduced
-        self._reduced_grads = 0
-        self._reduced_grads_max = 0
-
         # - setup buckets and tensor views
         model_size = sum([p.numel() for p in self.module.parameters()])
         self.buffer_max_size = min(reduce_buffer_size, model_size)
@@ -358,7 +357,7 @@ class ShardedDataParallel(nn.Module):
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
-        self._reduced_grads = 0
+        self._bucket_flush_callback_set = False

         # Do not reset the buckets
         if self.use_buckets:
@@ -375,8 +374,6 @@ class ShardedDataParallel(nn.Module):
                 bucket.reset()

-        self._bucket_flush_callback_set = False
-
         if not self.should_accumulate_grads:
             self.accumulate_grads_flipped = False
@@ -406,7 +403,7 @@ class ShardedDataParallel(nn.Module):
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
-                        Variable._execution_engine.queue_callback(self._flush_buckets)
+                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                         self._bucket_flush_callback_set = True

                     # Make sure that this is not fired twice
@@ -425,27 +422,21 @@ class ShardedDataParallel(nn.Module):
                             param.grad.data = param.grad.data.to(dtype=param.dtype)

                     # Async reduce for this buffer, log the future
-                    dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
-
                     self._work_handles.append(
                         Workhandle(
                             handle=dist.reduce(
-                                tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True
+                                tensor=param.grad.data,
+                                dst=self._local_to_global_rank[dst_rank],
+                                group=self.process_group,
+                                async_op=True,
                             ),
                             callback=cleanup,
                         )
                     )
-                    self._reduced_grads += 1

-                    # Opportunistically try to empty the queue
+                    # Opportunistically try to empty the queue, free memory
                     self._try_consume_work_handle()

-                    # If all the reduce operations have been called,
-                    # make sure that all the asynchronous calls have concluded before moving on
-                    # and execute the delayed actions (release gradients, unroll the buckets)
-                    if self._reduced_grads == self._reduced_grads_max:
-                        self._consume_work_handles()
-
         else:

             @torch.no_grad()
@@ -455,7 +446,7 @@ class ShardedDataParallel(nn.Module):
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
-                        Variable._execution_engine.queue_callback(self._flush_buckets)
+                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                         self._bucket_flush_callback_set = True

                     # Make sure that this is not fired twice
@@ -480,17 +471,10 @@ class ShardedDataParallel(nn.Module):
                                 callback=None,
                             )
                         )
-                        self._reduced_grads += 1

                         # Opportunistically try to empty the queue
                         self._try_consume_work_handle()

-                    # If all the reduce operations have been called,
-                    # make sure that all the asynchronous calls have concluded before moving on
-                    # and execute the delayed actions (release gradients, unroll the buckets)
-                    if self._reduced_grads == self._reduced_grads_max:
-                        self._consume_work_handles()
-
             return reduce

     def _setup_backward_hooks(self) -> None:
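The Workhandle entries appended in the two reduce paths above are drained in two ways: opportunistically after each reduce via _try_consume_work_handle (pop a handle only if its collective already finished, never block the backward pass), and exhaustively via _consume_work_handles once everything has been issued. A hedged sketch of that pattern (the deque and helper names are illustrative, not fairscale's; a torch.distributed async op returns a Work object exposing is_completed() and wait()):

from collections import deque

pending = deque()  # (work, callback) pairs, oldest first


def try_consume_one():
    # Non-blocking: only pop the oldest handle if its collective already finished
    if pending and pending[0][0].is_completed():
        _, callback = pending.popleft()
        if callback is not None:
            callback()


def consume_all():
    # Blocking: wait on every outstanding collective, then run its delayed action
    while pending:
        work, callback = pending.popleft()
        work.wait()
        if callback is not None:
            callback()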
@@ -550,9 +534,6 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """

-        # A priori, one reduce call per param
-        self._reduced_grads_max = len(self._trainable_params)
-
         if not self.use_buckets:
             return
@@ -572,7 +553,7 @@ class ShardedDataParallel(nn.Module):
             ]

             bucket = self.buckets[device][dst_rank]
-            bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)
+            bucket.destination = self._local_to_global_rank[dst_rank]

             # Criteria to decide whether this parameter is to be bucketed or not:
             # - enough room in the bucket
@@ -590,7 +571,6 @@ class ShardedDataParallel(nn.Module):
                 bucket.fill = fill_next

                 # Update the bucket
-                self._reduced_grads_max -= 1  # one less reduce call per bucketed grad
                 self.buckets[device][dst_rank].max_params_checked_in += 1
             else:
@@ -602,8 +582,6 @@ class ShardedDataParallel(nn.Module):
         for bucket in self._bucket_list:
             bucket.buffer.resize_(bucket.fill)
             bucket.sent = True
-            if bucket.max_params_checked_in > 0:
-                self._reduced_grads_max += 1  # one reduce call per bucket

     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -623,20 +601,22 @@ class ShardedDataParallel(nn.Module):
             if work_handle.callback is not None:
                 work_handle.callback()

-    # Flush all the buckets, just in case
-    def _flush_buckets(self) -> None:
+    def _flush_reduce_calls(self) -> None:
         if self._bucket_list is not None:
-            last_handle = None
             for bucket in self._bucket_list:
                 if not bucket.sent:
                     # Normalize the bucket in one go
                     bucket.buffer.mul_(self.world_size_scaling)

                     # Reduce the bucket
-                    last_handle = dist.reduce(
-                        tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                    self._work_handles.append(
+                        Workhandle(
+                            handle=dist.reduce(
+                                tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                            ),
+                            callback=None,
+                        )
                     )
                     bucket.sent = True

-            if last_handle is not None:
-                last_handle.wait()
+        self._consume_work_handles()
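With the _reduced_grads / _reduced_grads_max counters gone, knowing when the last reduce has been issued is delegated to the autograd engine: the first gradient hook to fire queues _flush_reduce_calls, which runs once the backward pass completes, reduces any bucket that has not been sent, and then drains all work handles. A minimal sketch of that end-of-backward callback mechanism, using the same internal PyTorch API as the diff (the toy model and print stand in for the real flush logic):

import torch
from torch.autograd import Variable

model = torch.nn.Linear(4, 4)
state = {"armed": False}


def end_of_backward():
    # Fires once, after the whole backward pass has finished
    print("backward done, flushing pending reduce calls")


def grad_hook(grad):
    # Arm the callback from the first gradient hook to fire, as the hooks in the diff do
    if not state["armed"]:
        Variable._execution_engine.queue_callback(end_of_backward)
        state["armed"] = True
    return grad


for p in model.parameters():
    p.register_hook(grad_hook)

loss = model(torch.randn(2, 4)).sum()
loss.backward()  # the message prints after all parameter grads are computed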
@@ -97,8 +97,9 @@ class OSS(Optimizer):
        self.world_size = dist.get_world_size(self.group)
        self.rank = dist.get_rank(self.group)
        self.global_rank = self.get_global_rank(self.group, self.rank)
+        self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}

        self._all_states: List[Dict[str, Any]] = []  # Optional consolidated optimizer state
        self._default_device = torch.device("cpu")
@@ -377,6 +378,7 @@ class OSS(Optimizer):
        # Make sure that the parameters are sorted in the state, as expected
        state_dict["state"] = dict(sorted(state_dict["state"].items()))
+
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -440,12 +442,10 @@ class OSS(Optimizer):
                    dist_device=self._default_device,
                )
            else:
-                global_rank = self.get_global_rank(self.group, rank)
-
                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
                broadcast_object(
                    torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._default_device),
-                    src_rank=global_rank,
+                    src_rank=self._local_to_global_rank[rank],
                    group=self.group,
                    dist_device=self._default_device,
                )
@@ -470,10 +470,9 @@ class OSS(Optimizer):
                )
            else:
                # Fetch the optim state from the other replicas
-                global_rank = self.get_global_rank(self.group, rank)
                replica_state = broadcast_object(
                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
-                    src_rank=global_rank,
+                    src_rank=self._local_to_global_rank[rank],
                    group=self.group,
                    dist_device=self._default_device,
                )
@@ -548,7 +547,7 @@ class OSS(Optimizer):
            if bucket.numel() > 0:
-                global_src_rank = self.get_global_rank(self.group, src_rank)
                last_work_handle = dist.broadcast(
-                    tensor=bucket, src=global_src_rank, group=self.group, async_op=True
+                    tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
                )

                # Only check on the last handle, they're all inlined on the same CUDA stream
...
 tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
-tests/nn/data_parallel/test_features_sharded_ddp.py
-tests/nn/data_parallel/test_pytorch_parity_sharded_ddp.py
+tests/nn/data_parallel/test_sharded_ddp_features.py
+tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
 tests/nn/pipe/skip/test_gpipe.py
 tests/nn/pipe/skip/test_verify_skippables.py
 tests/nn/pipe/skip/test_stash_pop.py
...
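For readers less familiar with the two classes touched by this diff, a hedged sketch of how OSS and ShardedDataParallel are typically wired together (assumes one process per rank; the backend, port, and tensor sizes below are arbitrary):

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS


def train(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", init_method="tcp://localhost:29501", rank=rank, world_size=world_size)

    model = torch.nn.Linear(32, 32)
    # OSS shards the optimizer state across ranks; ShardedDataParallel issues the per-rank
    # grad reduces from backward hooks, which is the hot path this commit trims.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
    model = ShardedDataParallel(model, optimizer, reduce_buffer_size=2 ** 20)

    for _ in range(3):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 32)).sum()
        loss.backward()  # grad reduces are issued from the hooks shown in the diff above
        optimizer.step()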