Unverified commit 69cbdf5d, authored by Benjamin Lefaudeux, committed by GitHub

[draft][chore] SDP: increase code coverage (#653)

* Increase code coverage: good practice, and it surfaces bugs. Hopefully getting to 100%.
* Small bugfix.
parent c65a48f3
```diff
@@ -224,10 +224,7 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)
 
     def to(  # type: ignore
-        self,
-        device: Optional[Union[int, torch.device]],
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
+        self, device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False,
     ) -> "ShardedDataParallel":
         """
         Moves and/or casts the parameters and buffers.
```
```diff
@@ -257,20 +254,24 @@ class ShardedDataParallel(nn.Module):
         Returns:
             Module: self.
         """
+        if isinstance(device, str):
+            device = torch.device(device)
+
         assert (
-            len(self._buckets.keys()) == 0 or device in self._buckets.keys()
+            device is None
+            or len(self._buckets.keys()) == 0
+            or device.type in map(lambda x: x.type, self._buckets.keys())
         ), "Changing devices is not supported, because this would break OSSs state"
         assert (
             len(self._buckets.keys()) < 2
         ), "Several devices specified to begin with, incompatible with setting a single device here"
 
-        for _device in self._buckets.keys():
-            for bucket in self._buckets[_device].values():
-                bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)
-
         self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
+        # Re-build the buckets, hooks, etc..
+        self.refresh_trainable()
+
     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
```
```diff
@@ -350,8 +351,8 @@ class ShardedDataParallel(nn.Module):
         See :meth:`torch.optim.Optimizer.zero_grad` for details.
         """
-        for index, trainable_param in enumerate(self._all_params):
-            if set_to_none and not self._should_bucket_grad[index]:
+        for index, trainable_param in enumerate(self._trainable_params):
+            if set_to_none and (len(self._should_bucket_grad) == 0 or not self._should_bucket_grad[index]):
                 trainable_param.grad = None
             elif trainable_param.grad is not None:
                 trainable_param.grad.zero_()
```
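The `_should_bucket_grad` guard above exists because bucketed gradients are views into a flat reduce buffer; assigning `None` would orphan their slot in the bucket, so only unbucketed gradients can be dropped. A pure-PyTorch illustration, where the flat tensor is a stand-in for one of ShardedDataParallel's reduce buckets:

```python
import torch

bucket = torch.zeros(8)                # stand-in for a flat reduce bucket
p = torch.nn.Parameter(torch.ones(4))
p.grad = bucket[:4]                    # the grad is a view into the bucket

p.grad.add_(1.0)                       # a backward pass would write here
p.grad.zero_()                         # in place: view and bucket stay linked
assert p.grad.data_ptr() == bucket.data_ptr()

p.grad = None                          # what set_to_none does for unbucketed grads;
                                       # a bucketed slot would now be orphaned
```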
```diff
@@ -108,7 +108,7 @@ def run_one_step(
     # Optim loop
     def closure():
-        optimizer.zero_grad()
+        ddp_model.zero_grad(set_to_none=True)
 
         with ddp_model.no_sync() if grad_accumulation else suppress():
             input_tensor = torch.rand((64, 2)).to(device)
```
```diff
@@ -193,7 +193,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
     # Optim loop
     def closure():
-        optimizer.zero_grad()
+        ddp_model.zero_grad(set_to_none=True)
         input_tensor = torch.rand((64, 2)).to(device)
         loss = ddp_model(input_tensor, input_tensor).abs().sum()
         loss.backward()
```
```diff
@@ -477,7 +477,7 @@ def test_two_optimizers():
     )
 
-def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
+def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     INPUT_DIM = 16
     BACH_SIZE = 10
     STEPS = 10
```
```diff
@@ -492,14 +492,19 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
         embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
     )
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
+    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)
 
     # Move the model to another device post-construction
     model = model.to(device)
 
     # Optim loop
+    set_to_none = True
+
     def closure():
-        optimizer.zero_grad()
+        nonlocal set_to_none
+        ddp_model.zero_grad(set_to_none=set_to_none)
+        set_to_none = not set_to_none
+
         # Force int inputs to prevent the first grad from firing
         input_tensor = torch.randint(10, (BACH_SIZE, INPUT_DIM)).to(device)
         loss = ddp_model(input_tensor).abs().sum()
```
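The `nonlocal` toggle above makes successive optimizer steps alternate between the two `zero_grad` code paths (`grad = None` versus in-place `grad.zero_()`). A self-contained sketch of the pattern, with `record` standing in for `ddp_model.zero_grad`:

```python
def make_closure(record):
    set_to_none = True

    def closure():
        nonlocal set_to_none
        record(set_to_none)        # stand-in for ddp_model.zero_grad(set_to_none=...)
        set_to_none = not set_to_none

    return closure

calls = []
closure = make_closure(calls.append)
for _ in range(4):
    closure()
assert calls == [True, False, True, False]  # both paths get exercised
```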
```diff
@@ -510,18 +515,28 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
     for i in range(STEPS):
         _ = optimizer.step(closure=closure)
 
+    # Stress test the .to() method
+    ddp_model.to(device=device, dtype=torch.float16)
+    ddp_model.to(device=device, dtype=torch.float32)
+
     dist.destroy_process_group()
 
 
 @skip_if_no_cuda
 @skip_if_single_gpu
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_gpt2(world_size):
+@pytest.mark.parametrize("reduce_buffer", [2 ** 23, 2 ** 40])
+def test_gpt2(world_size, reduce_buffer):
     # Check that having trainable unused params is fine
     backend = "gloo"
     device = "cuda"
     with temp_files_ctx(num=1) as temp_files:
-        mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)
+        mp.spawn(
+            run_test_gpt2,
+            args=(world_size, backend, device, temp_files[0], reduce_buffer),
+            nprocs=world_size,
+            join=True,
+        )
 
 
 def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
```
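For readers unfamiliar with the harness: `mp.spawn` prepends the worker rank to `args`, and the shared temp file acts as the process-group rendezvous. A hedged sketch of that pattern (the worker body and file path are placeholders, not fairscale test internals):

```python
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size, backend, tempfile_name):
    # mp.spawn passes the rank as the first argument automatically
    dist.init_process_group(
        backend, init_method=f"file://{tempfile_name}", rank=rank, world_size=world_size
    )
    # ... test body would run here ...
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2, "gloo", "/tmp/sharedfile"), nprocs=2, join=True)
```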