Unverified Commit 69cbdf5d authored by Benjamin Lefaudeux, committed by GitHub

[draft][chore] SDP : increase code coverage (#653)

* increasing the code coverage, good practice and surfacing bugs; hopefully getting to 100%
* small bugfix
parent c65a48f3
......@@ -224,10 +224,7 @@ class ShardedDataParallel(nn.Module):
return self.module(*inputs, **kwargs)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
self, device: Optional[torch.device], dtype: Optional[torch.dtype] = None, non_blocking: bool = False,
) -> "ShardedDataParallel":
"""
Moves and/or casts the parameters and buffers.
......@@ -257,20 +254,24 @@ class ShardedDataParallel(nn.Module):
Returns:
Module: self.
"""
if isinstance(device, str):
device = torch.device(device)
assert (
len(self._buckets.keys()) == 0 or device in self._buckets.keys()
device is None
or len(self._buckets.keys()) == 0
or device.type in map(lambda x: x.type, self._buckets.keys())
), "Changing devices is not supported, because this would break OSSs state"
assert (
len(self._buckets.keys()) < 2
), "Several devices specified to begin with, incompatible with setting a single device here"
for _device in self._buckets.keys():
for bucket in self._buckets[_device].values():
bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)
self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
# Re-build the buckets, hooks, etc..
self.refresh_trainable()
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
......@@ -350,8 +351,8 @@ class ShardedDataParallel(nn.Module):
See :meth:`torch.optim.Optimizer.zero_grad` for details.
"""
for index, trainable_param in enumerate(self._all_params):
if set_to_none and not self._should_bucket_grad[index]:
for index, trainable_param in enumerate(self._trainable_params):
if set_to_none and (len(self._should_bucket_grad) == 0 or not self._should_bucket_grad[index]):
trainable_param.grad = None
elif trainable_param.grad is not None:
trainable_param.grad.zero_()
......
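The zero_grad change above iterates over the trainable params only and tolerates an empty _should_bucket_grad list (no reduce buckets built), so set_to_none=True can null out gradients instead of zeroing them in place. A small sketch of the two behaviours on a plain parameter, independent of ShardedDataParallel:

    import torch

    p = torch.nn.Parameter(torch.ones(3))
    p.grad = torch.ones(3)

    # set_to_none=False path: keep the gradient tensor and zero it in place
    p.grad.zero_()
    assert p.grad is not None and p.grad.sum().item() == 0

    # set_to_none=True path: drop the tensor entirely; the next backward
    # allocates a fresh gradient, which frees memory between steps
    p.grad = None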
......@@ -108,7 +108,7 @@ def run_one_step(
# Optim loop
def closure():
optimizer.zero_grad()
ddp_model.zero_grad(set_to_none=True)
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
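The closure above wraps the forward/backward in no_sync() when gradient accumulation is being exercised, so the cross-rank reduction is skipped for those steps. A hedged sketch of how that context is typically used, reusing the ddp_model / optimizer names from the test; micro_batches is a hypothetical list of input tensors:

    # Accumulate gradients locally for all but the last micro-batch
    for micro_batch in micro_batches[:-1]:
        with ddp_model.no_sync():  # skip the cross-rank gradient reduce for this step
            ddp_model(micro_batch).abs().sum().backward()

    # The last micro-batch runs outside no_sync(), so the reduce fires once
    ddp_model(micro_batches[-1]).abs().sum().backward()
    optimizer.step()
    ddp_model.zero_grad(set_to_none=True)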
......@@ -193,7 +193,7 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduc
# Optim loop
def closure():
optimizer.zero_grad()
ddp_model.zero_grad(set_to_none=True)
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
......@@ -477,7 +477,7 @@ def test_two_optimizers():
)
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
INPUT_DIM = 16
BACH_SIZE = 10
STEPS = 10
......@@ -492,14 +492,19 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)
# Move the model to another device post-construction
model = model.to(device)
# Optim loop
set_to_none = True
def closure():
optimizer.zero_grad()
nonlocal set_to_none
ddp_model.zero_grad(set_to_none=set_to_none)
set_to_none = not set_to_none
# Force int inputs to prevent the first grad from firing
input_tensor = torch.randint(10, (BACH_SIZE, INPUT_DIM)).to(device)
loss = ddp_model(input_tensor).abs().sum()
......@@ -510,18 +515,28 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
for i in range(STEPS):
_ = optimizer.step(closure=closure)
# Stress test the .to() method
ddp_model.to(device=device, dtype=torch.float16)
ddp_model.to(device=device, dtype=torch.float32)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("world_size", [1, 2])
def test_gpt2(world_size):
@pytest.mark.parametrize("reduce_buffer", [2 ** 23, 2 ** 40])
def test_gpt2(world_size, reduce_buffer):
# Check that having trainable unused params is fine
backend = "gloo"
device = "cuda"
with temp_files_ctx(num=1) as temp_files:
mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)
mp.spawn(
run_test_gpt2,
args=(world_size, backend, device, temp_files[0], reduce_buffer),
nprocs=world_size,
join=True,
)
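The test now parameterizes reduce_buffer with 2 ** 23 and 2 ** 40 bytes. The interpretation below is an assumption about intent rather than something stated in the diff: the small buffer forces gradients to be chunked across several reduce buckets, while the oversized buffer should cover the whole model, exercising both bucketing paths. A sketch of the two constructions, mirroring the call already shown above:

    # Small buffer: gradients split across multiple reduce buckets (assumed intent)
    sdp_small = ShardedDataParallel(model, optimizer, reduce_buffer_size=2 ** 23)

    # Buffer far larger than the model, effectively a single bucket (assumed intent)
    sdp_large = ShardedDataParallel(model, optimizer, reduce_buffer_size=2 ** 40)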
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
......