"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0c575ace48dee00791377eedd19f1507ccc2314a"
Unverified commit ca74ee22, authored by Benjamin Lefaudeux, committed by GitHub


[OSS] Getting rid of the "should bucket" hash table, just use a list + non trainable params fix (#259)

* Getting rid of the "should bucket" hash table, just use a list
Properly handle all params, with or without requires_grad

* make sure that this case is unit tested
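
The core of the change: the per-parameter "should this param be bucketed?" decision used to live in a Dict[torch.Tensor, bool] and only covered trainable params; it now lives in a plain List[bool] that is filled and read back in the same deterministic parameter traversal order, with every parameter getting an entry whether or not it requires a gradient. A minimal sketch of that idea (illustrative names only, not the fairscale API):

from typing import List

import torch


def build_bucket_flags(params: List[torch.Tensor], bucket_size: int) -> List[bool]:
    """One flag per parameter, in traversal order: True if it fits in the broadcast bucket."""
    flags: List[bool] = []
    offset = 0
    for p in params:
        # Only trainable params that still fit in the bucket are bucketed,
        # but every param gets an entry so list indices stay aligned with traversal order.
        if p.requires_grad and offset + p.numel() < bucket_size:
            flags.append(True)
            offset += p.numel()
        else:
            flags.append(False)
    return flags


# The consumer later walks the same params in the same order and reads flags[i],
# instead of looking each tensor up in a hash table.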
parent bd7e25a5
@@ -88,14 +88,20 @@ class ShardedDataParallel(nn.Module):
         self.device_type = list(distinct_device_types)[0]
 
         # Scafolding to be able to reduce the grads during the BW pass
-        # several optimizers can be present each working on seperate parameter sets,
-        # we build an iterator which goes through all the parameters involved globally
-        self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
-        self._grad_to_be_reduced = [True for _ in self._param_iterator]
+        # several optimizers can be present each working on seperate parameter set which is spread across multiple ranks
+
+        # - we build an iterator which goes through all the parameters involved globally
+        all_param_iterator = chain(
+            *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
+        )
+        self._grad_to_be_reduced = [True for _ in filter(lambda x: x.requires_grad, all_param_iterator)]
+
+        # - keep track of the grads which have already been reduced
         self._reduced_grads: Dict[OSS, int] = {}
         self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
         self._clear_counters()
 
+        # - setup backward hooks which will be called by Torch's autograd in due time
         self._grad_accs: List[Callable] = []
         self._setup_backward_hooks()
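
For reference, the nested sum(..., []) calls in the hunk above flatten a device -> per-rank parameter lists mapping into one flat list before chaining across optimizers; a toy illustration (plain strings standing in for parameters):

from itertools import chain

# toy stand-in for optim.per_device_params: device -> list of per-rank parameter lists
per_device_params = {"cpu": [["w1", "w2"], ["w3"]], "cuda:0": [["w4"], ["w5", "w6"]]}

flat = sum([sum(p, []) for p in per_device_params.values()], [])
assert flat == ["w1", "w2", "w3", "w4", "w5", "w6"]

# chain(*...) then concatenates the flattened lists of several optimizers into one iterator
all_params = chain(*[flat])
assert list(all_params) == flat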
@@ -214,20 +220,24 @@ class ShardedDataParallel(nn.Module):
         # Go through the parameters, attach the hook
         for sharded_optimizer in self.sharded_optimizers:
-            for param, _ in sharded_optimizer.should_bucket_param.items():
-                if param.grad is not None and param.grad.requires_grad:
-                    raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
-
-                # Register the hook to the next function in line,
-                # so that the hook is fired when this grad has properly been computed
-                p_tmp = param.expand_as(param)
-                assert p_tmp.grad_fn is not None
-                grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                dst_rank = sharded_optimizer.param_to_rank[param]
-                index = len(self._grad_accs)
-
-                grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
-                self._grad_accs.append(grad_acc)  # keep this function in scope
+            for (
+                device_per_rank_params
+            ) in sharded_optimizer.per_device_params.values():  # all the params on this device (inc all ranks)
+                for device_params in device_per_rank_params:
+                    for param in filter(lambda x: x.requires_grad, device_params):
+                        if param.grad is not None and param.grad.requires_grad:
+                            raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
+
+                        # Register the hook to the next function in line,
+                        # so that the hook is fired when this grad has properly been computed
+                        p_tmp = param.expand_as(param)
+                        assert p_tmp.grad_fn is not None
+                        grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                        dst_rank = sharded_optimizer.param_to_rank[param]
+                        index = len(self._grad_accs)
+
+                        grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
+                        self._grad_accs.append(grad_acc)  # keep this function in scope
 
     def _sync_params_and_buffers(self) -> None:
         """
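
The hook registration kept inside the new nested loops is the usual AccumulateGrad trick: a throwaway expand_as view exposes the parameter's gradient-accumulation node, and a hook on that node fires exactly when the gradient is ready. A self-contained sketch (the helper name is hypothetical, not part of fairscale):

import torch


def attach_grad_ready_hook(param: torch.Tensor, hook):
    # expand_as(param) builds an autograd view; its grad_fn's next function is the
    # AccumulateGrad node of the leaf parameter, which runs once the grad is computed.
    p_tmp = param.expand_as(param)
    assert p_tmp.grad_fn is not None
    grad_acc = p_tmp.grad_fn.next_functions[0][0]
    grad_acc.register_hook(hook)
    return grad_acc  # the caller must keep this reference so the hook stays alive


w = torch.randn(4, requires_grad=True)
handle = attach_grad_ready_hook(w, lambda *unused: print("grad for w is ready"))
(w * 2).sum().backward()  # prints once the gradient of w has been accumulated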
@@ -116,7 +116,7 @@ class OSS(Optimizer):
             Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
             for _ in range(len(per_device))
         ]
-        self.should_bucket_param: Dict[torch.Tensor, bool] = {}
+        self.should_bucket_param: List[bool] = []
         self.work_handles: Deque[Workhandle] = deque()
         self._max_work_handles = -1
         self._setup_bucket_strategy()
@@ -385,6 +385,14 @@ class OSS(Optimizer):
                 if k != "params":
                     global_group[k] = v
 
+        # Force a re-partitioning, in case the model changed with the new state
+        self._partition_parameters.clear()
+        self._per_device_params.clear()
+        self._param_rank.clear()
+
+        # Update the bucketing strategy accordingly
+        self._setup_bucket_strategy()
+
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """Restore the global parameter groups as well as the shard.
@@ -393,8 +401,6 @@ class OSS(Optimizer):
             from a call to :meth:`state_dict`
         """
-        print("loading state dict")
-
         # Check whether we got a local or global dict
         if state_dict["local_state_dict"]:
             self.load_local_state_dict(state_dict)
@@ -426,6 +432,9 @@ class OSS(Optimizer):
             if len(param_groups) == len(self.optim.param_groups) + 1:
                 self.optim.add_param_group(param_groups[-1])
 
+        # Update the bucketing strategy accordingly
+        self._setup_bucket_strategy()
+
     @staticmethod
     def get_global_rank(group: Any, rank: int) -> int:
         if group is dist.group.WORLD:
@@ -510,6 +519,8 @@ class OSS(Optimizer):
         """Helper function to broadcast all the parameters from a given device"""
         with torch.no_grad():
+            i_param = 0
+
             for (
                 device,
                 device_params,
@@ -523,7 +534,7 @@ class OSS(Optimizer):
                 for param in params:
                     # Bucket broadcast
-                    if self.bucket_size > 0 and self.should_bucket_param[param]:
+                    if self.bucket_size > 0 and self.should_bucket_param[i_param]:
                         assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
                             bucket.max_size,
                             bucket.current_offset,
@@ -551,6 +562,8 @@ class OSS(Optimizer):
                             )
                         )
 
+                    i_param += 1
+
             self._consume_work_handles()
 
     def _consume_work_handles(self) -> None:
@@ -613,12 +626,11 @@ class OSS(Optimizer):
             for dst_rank, params in enumerate(per_rank_params):
                 offset = 0
 
-                # Only consider the params which will require a gradient
-                for param in filter(lambda p: p.requires_grad, params):
+                for param in params:
                     # Criteria to decide whether this parameter is to be bucketed or not:
                     # - enough room in the bucket
-                    if (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
-                        self.should_bucket_param[param] = True
+                    if param.requires_grad and (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
+                        self.should_bucket_param.append(True)
                         if offset == 0:
                             # count this bucket, only once
@@ -626,7 +638,7 @@ class OSS(Optimizer):
                         offset += param.numel()
                     else:
-                        self.should_bucket_param[param] = False
+                        self.should_bucket_param.append(False)
 
                 # Register the max offset for this buffer, and the reference rank
                 self.buckets[device][dst_rank].max_offset = offset
@@ -635,4 +647,4 @@ class OSS(Optimizer):
         # Determine the max work handles in flight:
         # - all the direct reduce/broadcast
-        self._max_work_handles += sum(not value for value in self.should_bucket_param.values())
+        self._max_work_handles += sum(not value for value in self.should_bucket_param)
@@ -7,6 +7,7 @@
 Testing OssDdp class.
 """
 
+from contextlib import suppress
 import copy
 import tempfile
 from typing import List
@@ -21,12 +22,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
+from fairscale.utils.testing import GPT2
 
 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
-from contextlib import suppress
-
-from fairscale.utils.testing import GPT2
 
 
 def run_one_step(rank, world_size, backend, device, temp_file_name):
@@ -44,6 +43,8 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
 
+    next(model.parameters()).requires_grad = False  # Test non-trainable parameters
+
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
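
A hedged end-to-end sketch of what the new test exercises, assuming the fairscale API at this commit; the single-process gloo group is only there to make the snippet self-contained (the actual test spawns one process per rank):

import tempfile

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process process group so the snippet runs standalone.
dist.init_process_group(backend="gloo", init_method=f"file://{tempfile.mkstemp()[1]}", rank=0, world_size=1)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
next(model.parameters()).requires_grad = False  # freeze one parameter, the case this commit fixes

optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

loss = ddp_model(torch.randn(16, 8)).sum()
loss.backward()   # gradients of trainable params are reduced to the rank that owns them
optimizer.step()  # each rank updates its own shard of the parameters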