Unverified commit ca74ee22, authored by Benjamin Lefaudeux and committed by GitHub
[OSS] Getting rid of the "should bucket" hash table, just use a list + non trainable params fix (#259)

* Getting rid of the "should bucket" hash table, just use a list
Properly handle all params, with or without requires_grad

* make sure that this case is unit tested
parent bd7e25a5
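
The core change: the per-parameter bucketing decision used to live in a Dict[torch.Tensor, bool] keyed by the parameter itself, and is now a positional List[bool], which only works because the flags are built and consumed with the same deterministic parameter traversal. A minimal sketch of the idea (hypothetical build_bucket_flags helper, not the fairscale API), including the non-trainable parameter handling:

import torch

# Minimal sketch (hypothetical helper, not fairscale code): one boolean per
# parameter, appended in traversal order. A parameter is bucketed only if it is
# trainable and still fits in the broadcast buffer.
def build_bucket_flags(params, bucket_size):
    flags, offset = [], 0
    for p in params:
        fits = p.requires_grad and (offset + p.numel()) < bucket_size
        flags.append(fits)  # frozen params still get an entry, keeping indices aligned
        if fits:
            offset += p.numel()
    return flags

p0 = torch.randn(4, 4)                      # requires_grad=False, i.e. non-trainable
p1 = torch.randn(8, 8, requires_grad=True)  # trainable and small enough to bucket
print(build_bucket_flags([p0, p1], bucket_size=1024))  # [False, True]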
@@ -88,14 +88,20 @@ class ShardedDataParallel(nn.Module):
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the BW pass
-        # several optimizers can be present, each working on separate parameter sets,
-        # we build an iterator which goes through all the parameters involved globally
-        self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
-        self._grad_to_be_reduced = [True for _ in self._param_iterator]
+        # several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks
+        # - we build an iterator which goes through all the parameters involved globally
+        all_param_iterator = chain(
+            *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
+        )
+        self._grad_to_be_reduced = [True for _ in filter(lambda x: x.requires_grad, all_param_iterator)]

        # - keep track of the grads which have already been reduced
        self._reduced_grads: Dict[OSS, int] = {}
        self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
        self._clear_counters()

        # - setup backward hooks which will be called by Torch's autograd in due time
        self._grad_accs: List[Callable] = []
        self._setup_backward_hooks()
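
For reference, per_device_params maps each device to a list of per-rank parameter lists, so two levels of nesting have to be collapsed before chaining across optimizers. A small sketch of the flattening on plain data (the string placeholders are illustrative only):

from itertools import chain

# Illustrative data only: device -> [params owned by rank 0, params owned by rank 1]
per_device_params = {
    "cuda:0": [["w0", "w1"], ["w2"]],
    "cpu": [["w3"], []],
}

# sum(p, []) flattens the per-rank lists of one device; the outer sum flattens
# across devices; chain(*) then concatenates across several sharded optimizers.
flat = sum([sum(p, []) for p in per_device_params.values()], [])
print(flat)                  # ['w0', 'w1', 'w2', 'w3']
print(list(chain(*[flat])))  # same list, one entry per sharded optimizer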
@@ -214,20 +220,24 @@
        # Go through the parameters, attach the hook
        for sharded_optimizer in self.sharded_optimizers:
-            for param, _ in sharded_optimizer.should_bucket_param.items():
-                if param.grad is not None and param.grad.requires_grad:
-                    raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
-                # Register the hook to the next function in line,
-                # so that the hook is fired when this grad has properly been computed
-                p_tmp = param.expand_as(param)
-                assert p_tmp.grad_fn is not None
-                grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                dst_rank = sharded_optimizer.param_to_rank[param]
-                index = len(self._grad_accs)
-                grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
-                self._grad_accs.append(grad_acc)  # keep this function in scope
+            for (
+                device_per_rank_params
+            ) in sharded_optimizer.per_device_params.values():  # all the params on this device (inc all ranks)
+                for device_params in device_per_rank_params:
+                    for param in filter(lambda x: x.requires_grad, device_params):
+                        if param.grad is not None and param.grad.requires_grad:
+                            raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
+                        # Register the hook to the next function in line,
+                        # so that the hook is fired when this grad has properly been computed
+                        p_tmp = param.expand_as(param)
+                        assert p_tmp.grad_fn is not None
+                        grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                        dst_rank = sharded_optimizer.param_to_rank[param]
+                        index = len(self._grad_accs)
+                        grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
+                        self._grad_accs.append(grad_acc)  # keep this function in scope

    def _sync_params_and_buffers(self) -> None:
        """
@@ -116,7 +116,7 @@ class OSS(Optimizer):
                Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
                for _ in range(len(per_device))
            ]

-        self.should_bucket_param: Dict[torch.Tensor, bool] = {}
+        self.should_bucket_param: List[bool] = []
        self.work_handles: Deque[Workhandle] = deque()
        self._max_work_handles = -1

        self._setup_bucket_strategy()
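
For context, the buckets above are flat, pre-allocated per-(device, rank) buffers: small parameters are packed into one tensor so that a single broadcast covers all of them, and anything that does not fit falls back to a direct broadcast. A rough sketch of such a flat bucket (not fairscale's Bucket class, whose exact interface is only partly visible in this diff):

import torch

class FlatBucket:
    """Toy packed-broadcast buffer; illustration only."""

    def __init__(self, max_size: int, dtype=torch.float32, device="cpu"):
        self.buffer = torch.zeros(max_size, dtype=dtype, device=device)
        self.max_size = max_size
        self.current_offset = 0
        self._slots = []

    def append(self, param: torch.Tensor) -> bool:
        end = self.current_offset + param.numel()
        if end > self.max_size:
            return False  # overflow: the caller broadcasts this tensor directly
        self.buffer[self.current_offset:end].copy_(param.data.view(-1))
        self._slots.append((param, self.current_offset))
        self.current_offset = end
        return True

    def unpack(self) -> None:
        # After broadcasting self.buffer, copy each slice back into its parameter.
        for param, start in self._slots:
            param.data.copy_(self.buffer[start:start + param.numel()].view_as(param))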
@@ -385,6 +385,14 @@
                if k != "params":
                    global_group[k] = v

+        # Force a re-partitioning, in case the model changed with the new state
+        self._partition_parameters.clear()
+        self._per_device_params.clear()
+        self._param_rank.clear()
+        # Update the bucketing strategy accordingly
+        self._setup_bucket_strategy()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the global parameter groups as well as the shard.
@@ -393,8 +401,6 @@
            from a call to :meth:`state_dict`
        """
-        print("loading state dict")

        # Check whether we got a local or global dict
        if state_dict["local_state_dict"]:
            self.load_local_state_dict(state_dict)
@@ -426,6 +432,9 @@
        if len(param_groups) == len(self.optim.param_groups) + 1:
            self.optim.add_param_group(param_groups[-1])

+        # Update the bucketing strategy accordingly
+        self._setup_bucket_strategy()

    @staticmethod
    def get_global_rank(group: Any, rank: int) -> int:
        if group is dist.group.WORLD:
@@ -510,6 +519,8 @@
        """Helper function to broadcast all the parameters from a given device"""
        with torch.no_grad():
+            i_param = 0

            for (
                device,
                device_params,
@@ -523,7 +534,7 @@
                for param in params:
                    # Bucket broadcast
-                    if self.bucket_size > 0 and self.should_bucket_param[param]:
+                    if self.bucket_size > 0 and self.should_bucket_param[i_param]:
                        assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
                            bucket.max_size,
                            bucket.current_offset,
@@ -551,6 +562,8 @@
                            )
                        )

+                    i_param += 1

        self._consume_work_handles()

    def _consume_work_handles(self) -> None:
@@ -613,12 +626,11 @@
            for dst_rank, params in enumerate(per_rank_params):
                offset = 0

-                # Only consider the params which will require a gradient
-                for param in filter(lambda p: p.requires_grad, params):
+                for param in params:
                    # Criteria to decide whether this parameter is to be bucketed or not:
                    # - enough room in the bucket
-                    if (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
-                        self.should_bucket_param[param] = True
+                    if param.requires_grad and (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
+                        self.should_bucket_param.append(True)

                        if offset == 0:
                            # count this bucket, only once
@@ -626,7 +638,7 @@
                        offset += param.numel()
                    else:
-                        self.should_bucket_param[param] = False
+                        self.should_bucket_param.append(False)

                # Register the max offset for this buffer, and the reference rank
                self.buckets[device][dst_rank].max_offset = offset
@@ -635,4 +647,4 @@
        # Determine the max work handles in flight:
        # - all the direct reduce/broadcast
-        self._max_work_handles += sum(not value for value in self.should_bucket_param.values())
+        self._max_work_handles += sum(not value for value in self.should_bucket_param)
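
With a positional list, the only contract is that the setup pass above and the broadcast pass (the i_param counter earlier in the diff) walk the parameters in exactly the same (device, rank, parameter) order, and that non-trainable parameters still get a False entry so the indices stay aligned. A sketch of the consuming side under that assumption (stand-in function, not the fairscale broadcast loop):

import torch

# Stand-in for the broadcast pass: walks params in the same order as the setup
# pass and reads the positional flags instead of a dict lookup keyed by tensor.
def broadcast_pass(params, should_bucket_param):
    i_param = 0
    for param in params:  # identical traversal order to the setup pass
        if should_bucket_param[i_param]:
            print("packed into bucket:", tuple(param.shape))
        else:
            print("direct broadcast:", tuple(param.shape))
        i_param += 1

params = [torch.randn(8, 8, requires_grad=True), torch.randn(4, 4)]  # second one frozen
broadcast_pass(params, should_bucket_param=[True, False])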
@@ -7,6 +7,7 @@
Testing OssDdp class.
"""

+from contextlib import suppress
import copy
import tempfile
from typing import List

@@ -21,12 +22,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils.testing import GPT2

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")

-from contextlib import suppress
-from fairscale.utils.testing import GPT2


def run_one_step(rank, world_size, backend, device, temp_file_name):
@@ -44,6 +43,8 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

+    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
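
The added test line freezes the first parameter before the optimizer is built, which exercises the new requires_grad handling end to end. A single-process approximation of what that covers (plain SGD here, no OSS/ShardedDataParallel and no torch.distributed setup):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
next(model.parameters()).requires_grad = False  # freeze one parameter, as in the test

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()  # the frozen parameter has no grad and is simply skipped

assert next(model.parameters()).grad is None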