Unverified Commit 81ac5b28 authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS unit test to check data group (#129)

* new unit test to catch rank issues in OSS
parent 22ff665d
@@ -101,6 +101,11 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 21.2 --reference_memory 4220 --reference_loss 0.63

+run_oss_gloo: &run_oss_gloo
+  - run:
+      name: Run OSS with Gloo
+      command: |
+        python benchmarks/oss.py --gloo --optim_type oss

@@ -254,6 +259,9 @@ jobs:
       - <<: *run_oss_benchmark

+      - <<: *run_oss_gloo

 workflows:
@@ -194,7 +194,7 @@ class OSS(Optimizer):
             device,
             device_params,
         ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-            self._broadcast_params(self._broadcast_buffers[device], device_params, self.group, self.global_rank)
+            self._broadcast_params(self._broadcast_buffers[device], device_params)

         return loss

@@ -408,10 +408,7 @@ class OSS(Optimizer):
             global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore

         return global_rank

-    @staticmethod
-    def _broadcast_params(
-        buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int
-    ) -> None:
+    def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
         """Helper function to broadcast all the parameters from a given device
         """
         buffer_size = buffers[0].numel()

@@ -425,7 +422,7 @@ class OSS(Optimizer):
             if len(params) == 0:
                 continue

-            global_rank = OSS.get_global_rank(group, rank)
+            global_rank = OSS.get_global_rank(self.group, rank)

             # Copy small parameters into per-GPU buffers
             i_bucketed = 0  # the number of tensors packed in the buffer

@@ -434,14 +431,14 @@ class OSS(Optimizer):
             # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
             while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
                 end = offset + params[i_bucketed].numel()
-                if global_rank == self_rank:
+                if global_rank == self.global_rank:
                     buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
                 offset = end
                 i_bucketed += 1

             if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)
+                future = dist.broadcast(tensor=buffer, src=global_rank, group=self.group, async_op=True)

-                if global_rank != self_rank:
+                if global_rank != self.global_rank:
                     # This request will need to be unrolled
                     bucket_requests.append((future, rank))

@@ -455,7 +452,7 @@ class OSS(Optimizer):
                     restore_require_grad.append(param)
                     param.requires_grad = False

-                requests.append(dist.broadcast(tensor=param, src=global_rank, group=group, async_op=True))
+                requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))

         # Unroll the initial packed small parameters
         for gate, rank in bucket_requests:
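Note (not part of the diff): _broadcast_params becomes a bound method so that it reads self.group and self.global_rank from the optimizer state instead of taking them as arguments, and every dist.broadcast source is resolved through OSS.get_global_rank. The standalone sketch below, with a hypothetical port and function name, illustrates the PyTorch behaviour this relies on: ranks are renumbered inside a sub-group created with dist.new_group, while dist.broadcast(src=...) still expects the global rank.

# Standalone sketch (not from this commit): group-local vs global ranks on a sub-group.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"  # hypothetical free port
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # Even ranks only: inside this group the ranks are renumbered 0, 1, 2
    group = dist.new_group(ranks=[0, 2, 4])

    if rank in [0, 2, 4]:
        group_rank = dist.get_rank(group)  # 0, 1 or 2
        # Same mapping as OSS.get_global_rank: back to the global numbering 0, 2, 4
        global_rank = dist.distributed_c10d._get_global_rank(group, group_rank)
        assert global_rank == rank

        # dist.broadcast expects the *global* rank as src, hence the change above
        tensor = torch.full((1,), float(rank))
        dist.broadcast(tensor, src=2, group=group)
        assert tensor.item() == 2.0

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(demo, args=(6,), nprocs=6, join=True)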
@@ -9,6 +9,7 @@
 import os

+import numpy as np
 import pytest
 import torch
 import torch.distributed as dist

@@ -334,3 +335,78 @@ def test_collect_shards():
     mp.spawn(
         run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
     )
+
+
+def run_test_multiple_groups(rank, world_size):
+    # Only work with the even ranks, to check that the global_rank indexing is properly used
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29501"
+    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    sub_group_ranks = [0, 2, 4]
+    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
+
+    # Make sure that all the ranks get different training data
+    # So that the sync check in between their models is meaningful
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Standard deep learning setup
+    device = "cpu"
+    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
+    loss_fn = torch.nn.L1Loss().to(device)
+
+    def check(optimizer):
+        # Just run a couple of epochs, check that the model is properly updated
+        for _ in range(epochs):
+            target = torch.rand((batch, target_width), device=device)
+            inputs = torch.rand((batch, input_width), device=device)
+
+            def closure():
+                optimizer.zero_grad()
+                output = model(inputs)
+                loss = loss_fn(output, target)
+                loss /= world_size
+                loss.backward()
+                dist.all_reduce(loss, group=process_group)  # Not strictly needed for the test below
+                return loss
+
+            _ = optimizer.step(closure=closure)
+
+            # Check that all the params are the same on all ranks
+            for pg in optimizer.param_groups:
+                for p in pg["params"]:
+                    receptacle = [p.clone() for _ in sub_group_ranks] if rank == 0 else []
+                    dist.gather(p, receptacle, dst=0, group=process_group)
+                    if rank == 0:
+                        for sync_p in receptacle[1:]:
+                            assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
+
+    if rank in sub_group_ranks:
+        # Model fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, Momentum is required to get a state to shard
+        optimizer = optim.OSS(
+            model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=2 ** 20
+        )
+        check(optimizer)
+
+        # Model not-fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, Momentum is required to get a state to shard
+        optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
+        check(optimizer)
+
+
+def test_multiple_groups():
+    world_size = 6
+    mp.spawn(
+        run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
+    )
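Note (not part of the diff): the equality check in the new test relies on dist.gather only needing the destination buffers on the dst rank; the other ranks pass an empty list. A minimal standalone version of that check, with illustrative names and assuming dist is already initialised with a process group pg over the even ranks, would look like this:

# Minimal sketch of the sync check used above (illustrative names; assumes an
# initialised process group pg over ranks [0, 2, 4]).
import torch
import torch.distributed as dist


def assert_params_in_sync(param, pg, group_size=3):
    # Destination buffers are only required on the gathering rank
    receptacle = [param.clone() for _ in range(group_size)] if dist.get_rank() == 0 else []
    dist.gather(param, receptacle, dst=0, group=pg)
    if dist.get_rank() == 0:
        for other in receptacle[1:]:
            assert torch.all(torch.eq(receptacle[0], other)), "Models differ in between ranks"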