Unverified commit ce2f64f9, authored by Benjamin Lefaudeux and committed by GitHub

[fix] OSS tensor view corner case + corresponding unit tests (#315)

parent 44b9bcd8
@@ -109,25 +109,10 @@ class OSS(Optimizer):
# Current default device is set by the parameters allocated to this rank
self._device = list(self.per_device_params.keys())[0]
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self.buffer_max_size = broadcast_buffer_size
# Get the correct size for the buckets, cannot be bigger than the model
model_size = sum([p.numel() for p in self.param_to_rank.keys()])
self.bucket_size = min(broadcast_buffer_size, model_size)
logging.info(
"Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
self.bucket_size / 2 ** 20, model_size / 2 ** 20
)
)
# Allocate one buffer per rank and per device to group the small parameters
for device, per_device in self.per_device_params.items():
self.buckets[device] = [
torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
for _ in range(len(per_device))
]
self.should_bucket_param: List[bool] = []
self.work_handles: Deque[Workhandle] = deque()
self._max_work_handles = -1
self._setup_bucket_strategy()
# Partition helpers
@@ -624,10 +609,24 @@ class OSS(Optimizer):
network requests have been issued.
"""
# Determine the max work handles in flight:
# - count all the buckets on the fly
self._max_work_handles = 0
# (re) allocate the buckets
# - Get the correct size for the buckets, cannot be bigger than the model
model_size = sum([p.numel() for p in self.param_to_rank.keys()])
self.bucket_size = min(self.buffer_max_size, model_size)
logging.info(
"Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
self.bucket_size / 2 ** 20, model_size / 2 ** 20
)
)
# - Allocate one buffer per rank and per device to group the small parameters
for device, per_device in self.per_device_params.items():
self.buckets[device] = [
torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
for _ in range(len(per_device))
]
# Devise the bucketing strategy
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
@@ -638,10 +637,6 @@ class OSS(Optimizer):
if param.requires_grad and (offset + param.numel()) < self.bucket_size:
self.should_bucket_param.append(True)
if offset == 0:
# count this bucket, only once
self._max_work_handles += 1
# This parameter becomes a view of the bucket
offset_next = offset + param.numel()
@@ -654,11 +649,3 @@ class OSS(Optimizer):
# Resize the bucket to remove lost space in the end
self.buckets[device][dst_rank].resize_(offset)
# Make sure that the memory previously taken by the bucketed parameters is released
if self._device.type == "cuda":
torch.cuda.empty_cache()
# Determine the max work handles in flight:
# - all the direct reduce/broadcast
self._max_work_handles += sum(not value for value in self.should_bucket_param)
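Taken together, the oss.py hunks above appear to move the bucket allocation out of __init__ and into the bucketing-strategy setup, so the buffers can be rebuilt whenever the strategy is recomputed. The sketch below is a self-contained approximation of that strategy, not the library's actual implementation: build_buckets, per_rank_params and buffer_max_size are hypothetical stand-ins for the corresponding OSS attributes.

import torch

def build_buckets(per_rank_params, buffer_max_size, device=torch.device("cpu")):
    # Sketch of the bucketing idea shown in the diff: the bucket can never be
    # larger than the model, small trainable params become views of the bucket,
    # and the bucket is shrunk to the space actually used.
    model_size = sum(p.numel() for params in per_rank_params for p in params)
    bucket_size = min(buffer_max_size, model_size)

    buckets, max_work_handles = [], 0
    for params in per_rank_params:
        bucket = torch.zeros(bucket_size, dtype=params[0].dtype, device=device)
        offset = 0
        for p in params:
            if p.requires_grad and offset + p.numel() < bucket_size:
                if offset == 0:
                    max_work_handles += 1  # one request per non-empty bucket
                end = offset + p.numel()
                bucket[offset:end].copy_(p.data.flatten())
                p.data = bucket[offset:end].view_as(p.data)  # param is now a view
                offset = end
            else:
                max_work_handles += 1  # this tensor is communicated on its own
        bucket.resize_(offset)  # drop the unused tail; existing views stay valid
        buckets.append(bucket)
    return buckets, max_work_handles

# Example: two "ranks", a few small CPU tensors standing in for model parameters.
params = [[torch.nn.Parameter(torch.rand(4)) for _ in range(3)] for _ in range(2)]
buckets, handles = build_buckets(params, buffer_max_size=2 ** 10)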
@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
known_third_party = ["datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
@@ -191,7 +191,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
# Test with all parameters trainable to begin with
def all_trainable():
params = []
for size in [4, 5, 2, 6, 4]:
sizes = [9, 7, 5, 3]
sizes_world = sizes * world_size
for size in sizes_world[:-1]:
params.append(torch.rand(size, 1))
# Make sure that the params are trainable, enforces size-based partitioning
@@ -204,8 +206,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
o.add_param_group({"params": [torch.rand(3, 1)]})
assert len(o.param_groups) == 2
# Verify that added group is added to the correct partition making all have 8 elements.
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
# Verify that added group is added to the correct partition making all have the same number of elements
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == sum(sizes)
assert len(o.optim.param_groups) == 2
# Test a pathological config with a first big non-trainable param
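A quick check of the arithmetic behind the assertion above: with sizes = [9, 7, 5, 3] and world_size = 4, all_trainable() creates the 15 tensors (sizes * world_size)[:-1], add_param_group contributes the missing 3-element tensor, and a size-greedy partition then gives every rank exactly sum(sizes) == 24 elements. The partitioner below is a stand-in used only to verify that count, not the heuristic OSS actually implements.

def greedy_partition(sizes, world_size):
    # Assign each tensor to the rank that currently holds the fewest elements.
    totals = [0] * world_size
    for numel in sizes:
        totals[totals.index(min(totals))] += numel
    return totals

world_size = 4
sizes = [9, 7, 5, 3]
trainable = (sizes * world_size)[:-1]                    # tensors built in all_trainable()
totals = greedy_partition(trainable + [3], world_size)   # plus the added param group
assert totals == [sum(sizes)] * world_size               # every rank ends up with 24 elements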
@@ -233,9 +236,10 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
def test_add_param_group():
world_size = 3
world_size = 4
if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
pytest.skip("Not enough GPUs for NCCL-based test")
world_size = min(world_size, torch.cuda.device_count())
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
@@ -591,10 +595,11 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
model_oss1 = torch.nn.Sequential(
torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, target_width),
).to(device)
model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden),).to(device)
head_oss1 = torch.nn.Linear(hidden, target_width).to(device)
model_oss2 = copy.deepcopy(model_oss1)
head_oss2 = copy.deepcopy(head_oss1)
# For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
# Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
@@ -602,16 +607,19 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# to keep the comparison apples-to-apples DDP is used in both cases
model_oss1 = DDP(module=model_oss1, device_ids=[rank],)
sharded_optimizer1 = optim.OSS(model_oss1.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer1.add_param_group({"params": head_oss1.parameters()})
model_oss2 = DDP(module=model_oss2, device_ids=[rank],)
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
def run_grad_step(device, model, optimizer):
def run_grad_step(device, model, head, optimizer):
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
model.zero_grad()
outputs = model(inputs)
outputs = head(model(inputs))
loss = loss_fn(outputs, target)
loss.backward()
@@ -622,21 +630,23 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# save and reload without taking any steps
sharded_optimizer2.consolidate_state_dict()
state_dict2 = sharded_optimizer2.state_dict()
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2)
# now take a step and check that parameters are equal
# take a step
run_grad_step(device, model_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, sharded_optimizer2)
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that model parameters are equal
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"
# take a step
run_grad_step(device, model_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, sharded_optimizer2)
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that model parameters are equal
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
@@ -653,8 +663,8 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]
# take a step
run_grad_step(device, model_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, sharded_optimizer2)
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that saving did not cause a change in the parameters
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
@@ -668,11 +678,12 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# reload the state_dict
sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
sharded_optimizer2.load_state_dict(state_dict2)
# take a step
run_grad_step(device, model_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, sharded_optimizer2)
run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)
# check that reloading a saved state dict does not change the parameters
for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
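The reworked state_dict test splits the model into a trunk and a separately added head, so that save/reload is exercised on an optimizer with two param groups. Roughly, the pattern being tested is sketched below; the single-process gloo group, module shapes and rendezvous address are illustrative choices made so the sketch runs standalone, not details taken from the test.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Minimal single-process process group so OSS can be constructed standalone.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29510", rank=0, world_size=1
)

trunk = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 32))
head = torch.nn.Linear(32, 4)

optimizer = OSS(trunk.parameters(), lr=0.1, momentum=0.99)
optimizer.add_param_group({"params": head.parameters()})

# Gather the sharded state (onto the reference rank by default), then snapshot it.
optimizer.consolidate_state_dict()
state = optimizer.state_dict()

# A freshly built optimizer has to declare the same param groups before the
# snapshot can be loaded back, hence the add_param_group calls in the test.
optimizer = OSS(trunk.parameters(), lr=0.1, momentum=0.99)
optimizer.add_param_group({"params": head.parameters()})
optimizer.load_state_dict(state)

dist.destroy_process_group()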