Unverified Commit d1fab39e authored by Benjamin Lefaudeux, committed by GitHub

[fix][minor] Change empty shard handling for OSS, do not rely on asserts (#460)

* change empty shard handling for OSS, do not rely on asserts
* code review
parent f565d443
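
Background: OSS assigns whole parameter tensors to ranks, so when the world size exceeds the number of parameter tensors some ranks inevitably end up with an empty shard. Previously this tripped the assert removed below; with this change such ranks simply receive no parameters and are skipped during bucket broadcasts. A minimal sketch of how an empty shard arises (the greedy size-based assignment and the variable names here are illustrative assumptions, not the exact OSS implementation):

import torch

# Hypothetical example: 4 "ranks" sharding a model that only has 2 parameter tensors.
model = torch.nn.Linear(1, 1)          # weight + bias -> 2 parameter tensors
world_size = 4
partitions = [[] for _ in range(world_size)]
sizes = [0] * world_size

# Greedy assignment to the rank with the fewest elements so far,
# similar in spirit to OSS.partition_parameters().
for param in model.parameters():
    rank = sizes.index(min(sizes))
    partitions[rank].append(param)
    sizes[rank] += param.numel()

print([len(p) for p in partitions])    # [1, 1, 0, 0] -> ranks 2 and 3 hold empty shards
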
@@ -140,13 +140,6 @@ class OSS(Optimizer):
                     param_group_rank["params"] = params
                     self._partition_parameters[rank].append(param_group_rank)
 
-            assert min(sum(len(pg["params"]) for pg in partition) for partition in self._partition_parameters) > 0, (
-                "One or more empty shards detected, the world size is too big or the model too small.\n"
-                + "Please reduce your world size if this is the model you would like to train\n"
-                + f"Current world size: {self.world_size}\n"
-                + "Current number of parameters: {}".format(sum(len(pg["params"]) for pg in self.param_groups))
-            )
-
         return self._partition_parameters
 
     @property
@@ -552,8 +545,11 @@ class OSS(Optimizer):
         for device in self.buckets.keys():
             for src_rank, bucket in enumerate(self.buckets[device]):
-                global_src_rank = self.get_global_rank(self.group, src_rank)
-                last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
+                if bucket.numel() > 0:
+                    global_src_rank = self.get_global_rank(self.group, src_rank)
+                    last_work_handle = dist.broadcast(
+                        tensor=bucket, src=global_src_rank, group=self.group, async_op=True
+                    )
 
         # Only check on the last handle, they're all inlined on the same CUDA stream
         if last_work_handle:
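
The new guard means a rank whose bucket is empty is simply skipped: there is nothing to synchronize for it, so no broadcast is issued. A standalone sketch of the same pattern, with illustrative names and the process-group setup omitted:

import torch
import torch.distributed as dist

def broadcast_non_empty_buckets(buckets, group, get_global_rank):
    """Broadcast each per-rank bucket, skipping ranks whose shard is empty."""
    last_work_handle = None
    for src_rank, bucket in enumerate(buckets):
        if bucket.numel() == 0:
            continue  # empty shard: nothing to broadcast for this rank
        global_src_rank = get_global_rank(group, src_rank)
        last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=group, async_op=True)

    # Only wait on the last handle; the broadcasts are queued on the same stream.
    if last_work_handle is not None:
        last_work_handle.wait()
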
@@ -597,4 +593,5 @@ class OSS(Optimizer):
                 else:
                     self.buckets[device][dst_rank] = bucket
             else:
-                self.buckets[device].append(torch.zeros(1, device=device))
+                # This rank has an empty shard, that's fine
+                self.buckets[device].append(torch.zeros(0, device=device))
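
The placeholder change pairs with the broadcast guard above: a zero-element tensor keeps the per-device bucket list aligned with the rank indices while guaranteeing that bucket.numel() > 0 is false, whereas the previous torch.zeros(1) placeholder still held one element and would have been broadcast. A quick standalone illustration (not fairscale code):

import torch

old_placeholder = torch.zeros(1)   # numel() == 1 -> would still be broadcast
new_placeholder = torch.zeros(0)   # numel() == 0 -> skipped by the guard above

assert old_placeholder.numel() == 1
assert new_placeholder.numel() == 0
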
@@ -262,19 +262,31 @@ def test_zero_grad():
     mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_catch_empty_shardd(rank, world_size, tempfile_name):
-    dist_init(rank, world_size, tempfile_name, backend="gloo")
+def run_test_empty_shard(rank, world_size, tempfile_name, backend):
+    dist_init(rank, world_size, tempfile_name, backend=backend)
     m = torch.nn.Linear(1, 1)
-    with pytest.raises(AssertionError):
-        _ = optim.OSS(m.parameters(), lr=0.1)
+    x = torch.rand(20, 1)
+
+    if torch.cuda.is_available():
+        m = m.to(rank)
+        x = x.to(rank)
+
+    o = optim.OSS(m.parameters(), lr=0.1)
+    y = m(x).sum()
+    y.backward()
+    o.step()
     dist.destroy_process_group()
 
 
-def test_empty_shard():
+@pytest.mark.parametrize("backend", ["gloo", "nccl"])
+def test_empty_shard(backend):
     world_size = 4
-    mp.spawn(run_test_catch_empty_shardd, args=(world_size, tempfile.mkstemp()[1]), nprocs=world_size, join=True)
+    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
+        world_size = min(world_size, torch.cuda.device_count())
+    if world_size == 1 or (backend == "nccl" and not torch.cuda.is_available()):
+        pytest.skip("Not enough GPUs to test with NCCL, or CUDA not present")
+
+    mp.spawn(run_test_empty_shard, args=(world_size, tempfile.mkstemp()[1], backend), nprocs=world_size, join=True)
 
 
 def run_test_step(rank, world_size, tempfile_name):
...