Unverified commit ece0cbf9 authored by Benjamin Lefaudeux, committed by GitHub

[chore][fix] SDP: yet another unit test improvement + bugfixes (#546)

* re-activating unit test
* removing changes that slipped in
parent 9474d75d
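For context, the manual-reduction path that this PR re-activates in the unit test follows the pattern below. This is a minimal sketch distilled from the test code in this diff (model and optimizer construction are assumed, e.g. an OSS optimizer wrapped by ShardedDataParallel); it is not library documentation.

def train_step_with_manual_reduce(model, input_tensor):
    # `model` is assumed to be a ShardedDataParallel-wrapped module
    model.zero_grad()
    # Run the backward pass with automatic gradient sync disabled ...
    with model.no_sync():
        loss = model(input_tensor).abs().sum()
        loss.backward()
    # ... then trigger the bucketed gradient reduction explicitly.
    # This is the path the `_manual_reduce` hook list added in this diff supports.
    model.reduce()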
@@ -178,6 +178,7 @@ class ShardedDataParallel(nn.Module):
# - setup backward hooks which will be called by Torch's autograd in due time
self._grad_accs: List[Callable] = []
self._manual_reduce: List[Callable] = []
# passing a handle to torch.nn.SyncBatchNorm layer
self._passing_sync_batchnorm_handle(self.module)
@@ -202,7 +203,6 @@ class ShardedDataParallel(nn.Module):
trainable_mask = list(map(_trainable, self._all_params))
if trainable_mask != self._reference_trainable_mask:
logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
self.refresh_trainable()
self._reference_trainable_mask = trainable_mask
@@ -303,7 +303,8 @@ class ShardedDataParallel(nn.Module):
), "No grads waiting to be reduced, maybe that this was called twice or there was no BW pass ?"
# Trigger all the current BW hooks
_ = map(lambda x: x(), self._grad_accs)
self._bucket_flush_callback_set = True # no need to flush in the end, we own the callback execution
_ = list(map(lambda x: x(), self._manual_reduce))
# Make sure that all the futures are consumed
self._consume_work_handles()
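A note on the hunk above: in Python 3 `map` is lazy, so the old `_ = map(lambda x: x(), self._grad_accs)` line never actually executed anything; the replacement keeps the per-parameter reduce closures in `self._manual_reduce` and forces them to run by materializing the map with `list(...)`. A tiny standalone illustration of the laziness (plain Python, not fairscale code):

calls = []
hooks = [lambda i=i: calls.append(i) for i in range(3)]

_ = map(lambda h: h(), hooks)        # lazy iterator, nothing runs: calls == []
_ = list(map(lambda h: h(), hooks))  # materialized, every hook runs: calls == [0, 1, 2]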
@@ -433,6 +434,7 @@ class ShardedDataParallel(nn.Module):
@torch.no_grad()
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
@@ -478,6 +480,7 @@ class ShardedDataParallel(nn.Module):
# Go through the parameters, attach the hook
self._grad_accs = []
self._manual_reduce = []
for index, param in enumerate(self._trainable_params):
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
@@ -489,8 +492,10 @@ class ShardedDataParallel(nn.Module):
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = self._trainable_param_to_rank[param]
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
self._grad_accs.append(grad_acc) # keep this function in scope
reduce_function = self._get_reduce_fn(index, param, dst_rank)
grad_acc.register_hook(reduce_function)
self._manual_reduce.append(reduce_function)
self._grad_accs.append(grad_acc) # keep this hook in scope
@torch.no_grad()
def _sync_params_and_buffers(self) -> None:
......
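The hook registration above hinges on PyTorch's per-parameter gradient accumulator node, obtained through the `p_tmp.grad_fn.next_functions[0][0]` trick (the same one DistributedDataParallel uses). Below is a minimal, standalone sketch of that mechanism, independent of fairscale, to show why the references have to be kept in scope:

import torch

param = torch.nn.Parameter(torch.randn(4))

# expand_as builds a tiny graph whose next function is the parameter's AccumulateGrad node
p_tmp = param.expand_as(param)
grad_acc = p_tmp.grad_fn.next_functions[0][0]

def on_grad_ready(*_):
    # Called by autograd once the gradient has been accumulated into param.grad
    print("grad ready, norm:", param.grad.norm().item())

grad_acc.register_hook(on_grad_ready)

# Keep both the node and the hook alive, otherwise they can be garbage collected
# before backward runs (this is what self._grad_accs / self._manual_reduce do above).
handles = [grad_acc, on_grad_ready]

loss = (param * 2).sum()
loss.backward()  # prints once param.grad is populated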
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, List, Optional, Union
import torch
class GradBucket:
"""
Helper class to simplify the handling of gradient buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
self._max_size = size
self._params: List[torch.Tensor] = []
self._param_ids: List[int] = []
self._fill = 0
self._is_collapsed = False
# The actual flat tensor
self.buffer: Optional[torch.Tensor] = torch.zeros(self._max_size, dtype=dtype, device=device)
self.params_checked_in = 0
self.destination = destination
self.sent = True
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> "GradBucket":
"""
Move the underlying buffer
"""
if self._is_collapsed:
self.rebuild()
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
# Tensor.to() is not in-place: rebind the buffer and return self, as the signature advertises
self.buffer = self.buffer.to(device, dtype, non_blocking)
return self
def zero(self) -> None:
"""
Set all the grads to zero
"""
if self.buffer is not None:
self.buffer.fill_(0.0)
@torch.no_grad()
def add_grad(self, param: torch.Tensor) -> None:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
"""
Release the buffer from memory. The bucket will need to be rebuilt before use
"""
if not self._is_collapsed:
for p in self._params:
assert p.grad is not None
p.grad.detach_()
p.grad = None
self.buffer = None
self._fill = 0
self.params_checked_in = 0
self._is_collapsed = True
@torch.no_grad()
def rebuild(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
if self._is_collapsed:
self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)
for p in self._params:
self._add_grad_as_view(p)
self._is_collapsed = False
@torch.no_grad()
def shrink(self) -> None:
"""
Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
"""
assert self.buffer is not None, "Cannot shrink a collapsed bucket, please rebuild"
self.buffer = self.buffer.resize_(self._fill).clone()
self._fill = 0
for p in self._params:
self._add_grad_as_view(p)
self._max_size = self._fill
@torch.no_grad()
def _add_grad_as_view(self, param: torch.Tensor) -> None:
assert self.buffer is not None
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current grad value, if any
if param.grad is not None:
# keep param.grad in place
self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
else:
param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
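To make the flat-buffer mechanics concrete, here is a small hypothetical usage sketch (the parameter shapes and destination rank are invented for illustration). After add_grad, every param.grad is a view into the single contiguous buffer, which is what lets ShardedDDP reduce a whole bucket with one collective call:

import torch

# Two fake parameters whose gradients will live in one flat bucket
p1 = torch.nn.Parameter(torch.randn(3, 2))
p2 = torch.nn.Parameter(torch.randn(4))

total_numel = p1.numel() + p2.numel()
bucket = GradBucket(size=total_numel, dtype=p1.dtype, device=p1.device, destination=0)

for p in (p1, p2):
    bucket.add_grad(p)  # p.grad is now a view into bucket.buffer

# Writing through the flat buffer is visible through every per-parameter view
bucket.buffer.fill_(1.0)
assert torch.all(p1.grad == 1.0)
assert torch.all(p2.grad == 1.0)

# In ShardedDDP a single collective (e.g. dist.reduce to `destination`) would now
# cover both gradients at once by operating on bucket.buffer.

# The buffer can be released between steps and rebuilt on demand
bucket.collapse()  # p1.grad / p2.grad are dropped, buffer freed
bucket.rebuild()   # buffer reallocated, grads re-attached as zero-filled views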
@@ -54,6 +54,7 @@ def run_ddp_parity(
fp16_reduction,
clip_grad_norm,
amp,
manual_reduction,
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
@@ -64,132 +65,134 @@ def run_ddp_parity(
NUMBER_BATCHS = 5
BATCH_SIZE = 8
def check_parity(manual_reduction: bool):
# Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
print(
f"{rank}: Checking configuration: accumulate {grad_accumulation}"
+ f" - change train graph {change_train_graph}"
+ f" - amp {amp}"
+ f" - manual reduction {manual_reduction}"
+ f" - buffers {reduce_buffer_size}",
flush=True,
)
# The API should be the exact same in between the sharded and non-sharded variants, generic closure
def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
    accumulate_steps = 3 if should_accumulate else 1
    model.zero_grad()

    def step():
        if scaler is not None:
            with torch.cuda.amp.autocast():
                loss = model(input_tensor).abs().sum()
                scaler.scale(loss).backward()
        else:
            loss = model(input_tensor).abs().sum()
            loss.backward()

    with model.no_sync() if should_accumulate else suppress():
        for _ in range(accumulate_steps - 1):
            step()

    if not _manual_reduction:
        step()
    else:
        with model.no_sync():
            step()
        model.reduce()
# Any model works. Add one different buffer per rank
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
# Make sure that the model starts with non-trainable, so that we check for the buckets to be
# properly reassigned when/if this changes
next(model.parameters()).requires_grad = False
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
sharded_ddp_model = ShardedDataParallel(
module=model,
sharded_optimizer=sharded_optimizer,
broadcast_buffers=True,
reduce_buffer_size=reduce_buffer_size,
reduce_fp16=fp16_reduction,
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
if fp16_reduction:
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook) # type: ignore
ddp_scaler = TorchGradScaler() if amp else None
sharded_scaler = ShardedGradScaler() if amp else None
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params(sharded_ddp_model, ddp_model)
# Typical training loop, check that we get the exact same results as DDP
for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
def ddp_closure(input_tensor=input_tensor):
return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
def sharded_closure(input_tensor=input_tensor):
return closure(
sharded_ddp_model,
sharded_scaler,
input_tensor,
grad_accumulation,
_manual_reduction=manual_reduction,
)
# Step/scale both
for _scaler, _closure, _optimizer in (
(ddp_scaler, ddp_closure, ddp_optimizer),
(sharded_scaler, sharded_closure, sharded_optimizer),
):
if _scaler is not None:
_ = _closure(input_tensor)
_scaler.step(_optimizer)
_scaler.update()
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
# Check that the two grad norm are equivalent
# NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
# This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
# be valid for ShardedDDP
if clip_grad_norm:
total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0) # type: ignore
if not torch.isnan(total_norm):
oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
assert torch.allclose(
oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
else:
print(rank, "NaN grad norm in DDP", flush=True)
# Flip the trainability of the first parameter back and forth
if i == 0 and change_train_graph:
next(sharded_ddp_model.parameters()).requires_grad = not next(
sharded_ddp_model.parameters()
).requires_grad
next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
# Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
for manual_reduction in manual_reductions:
print(
f"{rank}: Checking configuration: accumulate {grad_accumulation}"
+ f" - change train graph {change_train_graph}"
+ f" - amp {amp}"
+ f" - manual reduction {manual_reduction}"
+ f" - buffers {reduce_buffer_size}",
flush=True,
)
check_parity(manual_reduction=manual_reduction)
torch.cuda.synchronize()
torch.distributed.barrier()
# Any model works. Add one different buffer per rank
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
# Make sure that the model starts with non-trainable, so that we check for the buckets to be
# properly reassigned when/if this changes
next(model.parameters()).requires_grad = False
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
sharded_ddp_model = ShardedDataParallel(
module=model,
sharded_optimizer=sharded_optimizer,
broadcast_buffers=True,
reduce_buffer_size=reduce_buffer_size,
reduce_fp16=fp16_reduction,
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
if fp16_reduction:
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook) # type: ignore
ddp_scaler = TorchGradScaler() if amp else None
sharded_scaler = ShardedGradScaler() if amp else None
# The model should be synchronized in between the ranks at construction time, check that
check_same_model_params(sharded_ddp_model, ddp_model)
# Typical training loop, check that we get the exact same results as DDP
for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
def ddp_closure(input_tensor=input_tensor):
return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
def sharded_closure(input_tensor=input_tensor):
return closure(
sharded_ddp_model, sharded_scaler, input_tensor, grad_accumulation, _manual_reduction=manual_reduction,
)
# Step/scale both
for _scaler, _closure, _optimizer in (
(ddp_scaler, ddp_closure, ddp_optimizer),
(sharded_scaler, sharded_closure, sharded_optimizer),
):
if _scaler is not None:
_ = _closure(input_tensor)
_scaler.step(_optimizer)
_scaler.update()
else:
_optimizer.step(_closure())
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")
# Check that the two grad norm are equivalent
# NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
# This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
# be valid for ShardedDDP
# NOTE: DDP does not handle parameters trainability being changed after the fact, see
# https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
if clip_grad_norm and not change_train_graph:
total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0) # type: ignore
if not torch.isnan(total_norm):
oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)
if not allclose:
# Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
if p_ddp.grad is not None:
if p_sdp.grad is not None:
print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
else:
print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)
assert (
allclose
), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
else:
print(rank, "NaN grad norm in DDP", flush=True)
# Flip the trainability of the first parameter back and forth
if i == 0 and change_train_graph:
next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
dist.destroy_process_group()
@@ -202,7 +205,13 @@ def run_ddp_parity(
@pytest.mark.parametrize("fp16_reduction", _test_fp16_reduction)
@pytest.mark.parametrize("clip_grad_norm", [True, False])
@pytest.mark.parametrize("amp", _test_amp)
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp):
@pytest.mark.parametrize("manual_reduction", [True, False])
def test_ddp_parity(
reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp, manual_reduction
):
if manual_reduction and change_train_graph:
pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")
world_size = torch.cuda.device_count()
backend = dist.Backend.NCCL
mp.spawn(
@@ -217,6 +226,7 @@ def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, f
fp16_reduction,
clip_grad_norm,
amp,
manual_reduction,
),
nprocs=world_size,
join=True,
......
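The test leans on a check_same_model_params helper that is defined elsewhere in the fairscale test utilities and is not part of this diff. For readers following along, a hypothetical sketch of what such a parity check can look like (the real helper may differ):

import torch

def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    # Parameters must match across the two wrapped models
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-6), f"Parameters differ. {message}"

    # Buffers (e.g. the per-rank test_buffer registered in the test) must have been broadcast identically
    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Buffers differ. {message}"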