Unverified commit ece0cbf9, authored by Benjamin Lefaudeux, committed by GitHub

[chore][fix] SDP: yet another unit test improvement + bugfixes (#546)

* re-activating a unit test
* removing a change that slipped in
parent 9474d75d
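The commit wires every per-parameter reduce function into `self._manual_reduce`, so that `ShardedDataParallel.reduce()` can trigger gradient reduction explicitly instead of relying only on the backward hooks. Below is a minimal sketch of that call sequence, assuming a single-process "gloo" group, a toy model, and made-up hyper-parameters purely for illustration; none of it is taken from this commit.

# Minimal sketch of the manual-reduction flow (assumed single-process "gloo" group, toy model).
import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(2, 2)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)
ddp = ShardedDataParallel(model, optimizer)

with ddp.no_sync():                      # backward runs, reduction is skipped
    ddp(torch.rand(8, 2)).abs().sum().backward()
ddp.reduce()                             # fire the stored reduce functions manually
optimizer.step()

dist.destroy_process_group()

This mirrors the `with model.no_sync(): step(); model.reduce()` pattern exercised by the updated unit test further down.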
@@ -178,6 +178,7 @@ class ShardedDataParallel(nn.Module):
         # - setup backward hooks which will be called by Torch's autograd in due time
         self._grad_accs: List[Callable] = []
+        self._manual_reduce: List[Callable] = []

         # passing a handle to torch.nn.SyncBatchNorm layer
         self._passing_sync_batchnorm_handle(self.module)
@@ -202,7 +203,6 @@ class ShardedDataParallel(nn.Module):
             trainable_mask = list(map(_trainable, self._all_params))
             if trainable_mask != self._reference_trainable_mask:
                 logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
                 self.refresh_trainable()
                 self._reference_trainable_mask = trainable_mask
@@ -303,7 +303,8 @@ class ShardedDataParallel(nn.Module):
         ), "No grads waiting to be reduced, maybe that this was called twice or there was no BW pass ?"

         # Trigger all the current BW hooks
-        _ = map(lambda x: x(), self._grad_accs)
+        self._bucket_flush_callback_set = True  # no need to flush in the end, we own the callback execution
+        _ = list(map(lambda x: x(), self._manual_reduce))

         # Make sure that all the futures are consumed
         self._consume_work_handles()
@@ -433,6 +434,7 @@ class ShardedDataParallel(nn.Module):
         @torch.no_grad()
         def reduce(*_: Any) -> None:
+            # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                 assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
@@ -478,6 +480,7 @@ class ShardedDataParallel(nn.Module):
         # Go through the parameters, attach the hook
         self._grad_accs = []
+        self._manual_reduce = []
         for index, param in enumerate(self._trainable_params):
             if param.grad is not None and param.grad.requires_grad:
                 raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
@@ -489,8 +492,10 @@ class ShardedDataParallel(nn.Module):
             grad_acc = p_tmp.grad_fn.next_functions[0][0]
             dst_rank = self._trainable_param_to_rank[param]

-            grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
-            self._grad_accs.append(grad_acc)  # keep this function in scope
+            reduce_function = self._get_reduce_fn(index, param, dst_rank)
+            grad_acc.register_hook(reduce_function)
+            self._manual_reduce.append(reduce_function)
+            self._grad_accs.append(grad_acc)  # keep this hook in scope

     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
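For context, `grad_acc` above is the AccumulateGrad autograd node of each parameter; registering a hook on it is what lets the reduce function run as soon as the gradient has been written. A standalone sketch of that mechanism in plain PyTorch (toy parameter, no fairscale involved):

# Standalone sketch of the AccumulateGrad hook trick used above (plain PyTorch, toy parameter).
import torch

param = torch.nn.Parameter(torch.randn(3))
p_tmp = param.expand_as(param)                 # builds a tiny graph ending in AccumulateGrad
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node for `param`

def reduce_fn(*_unused):
    # At this point param.grad has just been accumulated; a real implementation
    # would bucket it and kick off an asynchronous reduce here.
    print("grad ready, norm:", param.grad.norm().item())

grad_acc.register_hook(reduce_fn)
hooks_alive = [grad_acc]  # keep a reference so the node (and its hook) stay in scope

(param * 2.0).sum().backward()  # triggers reduce_fn once the gradient is written

Keeping the reduce functions in `self._manual_reduce` simply makes the same callables reachable outside the autograd engine, which is what `ShardedDataParallel.reduce()` now exploits.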
...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, List, Optional, Union
import torch
class GradBucket:
"""
Helper class to simplify the handling of gradient buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
self._max_size = size
self._params: List[torch.Tensor] = []
self._param_ids: List[int] = []
self._fill = 0
self._is_collapsed = False
# The actual flat tensor
self.buffer: Optional[torch.Tensor] = torch.zeros(self._max_size, dtype=dtype, device=device)
self.params_checked_in = 0
self.destination = destination
self.sent = True
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> "GradBucket":
"""
Move the underlying buffer
"""
if self._is_collapsed:
self.rebuild()
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        self.buffer = self.buffer.to(device, dtype, non_blocking=non_blocking)
        return self
def zero(self) -> None:
"""
Set all the grads to zero
"""
if self.buffer is not None:
self.buffer.fill_(0.0)
@torch.no_grad()
def add_grad(self, param: torch.Tensor) -> None:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
"""
Release the buffer from memory. The bucket will need to be rebuilt before use
"""
if not self._is_collapsed:
for p in self._params:
assert p.grad is not None
p.grad.detach_()
p.grad = None
self.buffer = None
self._fill = 0
self.params_checked_in = 0
self._is_collapsed = True
@torch.no_grad()
def rebuild(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
if self._is_collapsed:
self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)
for p in self._params:
self._add_grad_as_view(p)
self._is_collapsed = False
@torch.no_grad()
def shrink(self) -> None:
"""
Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
"""
assert self.buffer is not None, "Cannot shrink a collapsed bucket, please rebuild"
self.buffer = self.buffer.resize_(self._fill).clone()
self._fill = 0
for p in self._params:
self._add_grad_as_view(p)
self._max_size = self._fill
@torch.no_grad()
def _add_grad_as_view(self, param: torch.Tensor) -> None:
assert self.buffer is not None
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current grad value, if any
if param.grad is not None:
# keep param.grad in place
self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
else:
param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
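A short usage sketch of the bucket lifecycle above (standalone, CPU-only; the import path is an assumption, only the class shown here is taken as given):

# Usage sketch for GradBucket: flatten two parameter grads into one buffer,
# then exercise shrink / collapse / rebuild. The import path is assumed.
import torch
from fairscale.nn.misc import GradBucket

p1 = torch.nn.Parameter(torch.randn(4))
p2 = torch.nn.Parameter(torch.randn(8))

bucket = GradBucket(size=32, dtype=torch.float32, device=torch.device("cpu"), destination=0)
for p in (p1, p2):
    if bucket.can_add_grad_view(p):
        bucket.add_grad(p)  # p.grad now aliases a slice of bucket.buffer

bucket.buffer.fill_(1.0)  # writing the flat buffer updates every registered .grad
assert torch.all(p1.grad == 1.0) and torch.all(p2.grad == 1.0)

bucket.shrink()    # trim the buffer to the 12 elements actually checked in
bucket.collapse()  # drop the buffer and the .grad views to free memory
bucket.rebuild()   # re-allocate the buffer and re-attach the gradient views

The point of the design is that a single collective on `bucket.buffer` reduces all the checked-in gradients at once, while each `param.grad` stays usable as an ordinary tensor view.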
@@ -54,6 +54,7 @@ def run_ddp_parity(
     fp16_reduction,
     clip_grad_norm,
     amp,
+    manual_reduction,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
@@ -64,132 +65,134 @@ def run_ddp_parity(
     NUMBER_BATCHS = 5
     BATCH_SIZE = 8

-    def check_parity(manual_reduction: bool):
+    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
+    print(
+        f"{rank}: Checking configuration: accumulate {grad_accumulation}"
+        + f" - change train graph {change_train_graph}"
+        + f" - amp {amp}"
+        + f" - manual reduction {manual_reduction}"
+        + f" - buffers {reduce_buffer_size}",
+        flush=True,
+    )

     # The API should be the exact same in between the sharded and non-sharded variants, generic closure
     def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
         accumulate_steps = 3 if should_accumulate else 1
         model.zero_grad()

         def step():
             if scaler is not None:
                 with torch.cuda.amp.autocast():
                     loss = model(input_tensor).abs().sum()
                     scaler.scale(loss).backward()
             else:
                 loss = model(input_tensor).abs().sum()
                 loss.backward()

         with model.no_sync() if should_accumulate else suppress():
             for _ in range(accumulate_steps - 1):
                 step()

         if not _manual_reduction:
             step()
         else:
             with model.no_sync():
                 step()
             model.reduce()

     # Any model works. Add one different buffer per rank
     model = _get_mlp()
     model.register_buffer("test_buffer", torch.ones((1)) * rank)
     model.to(device)

     # Make sure that the model starts with non-trainable, so that we check for the buckets to be
     # properly reassigned when/if this changes
     next(model.parameters()).requires_grad = False

     sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
     sharded_ddp_model = ShardedDataParallel(
         module=model,
         sharded_optimizer=sharded_optimizer,
         broadcast_buffers=True,
         reduce_buffer_size=reduce_buffer_size,
         reduce_fp16=fp16_reduction,
     )

     ddp_model_single = copy.deepcopy(model)
     ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
     ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

     if fp16_reduction:
         from dist.algorithms.ddp_com_hooks.default_hooks import fp16_compress_hook

         ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

     ddp_scaler = TorchGradScaler() if amp else None
     sharded_scaler = ShardedGradScaler() if amp else None

     # The model should be synchronized in between the ranks at construction time, check that
     check_same_model_params(sharded_ddp_model, ddp_model)

     # Typical training loop, check that we get the exact same results as DDP
     for i in range(NUMBER_BATCHS):
         input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

         def ddp_closure(input_tensor=input_tensor):
             return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

         def sharded_closure(input_tensor=input_tensor):
             return closure(
-                sharded_ddp_model,
-                sharded_scaler,
-                input_tensor,
-                grad_accumulation,
-                _manual_reduction=manual_reduction,
+                sharded_ddp_model, sharded_scaler, input_tensor, grad_accumulation, _manual_reduction=manual_reduction,
             )

         # Step/scale both
         for _scaler, _closure, _optimizer in (
             (ddp_scaler, ddp_closure, ddp_optimizer),
             (sharded_scaler, sharded_closure, sharded_optimizer),
         ):
             if _scaler is not None:
                 _ = _closure(input_tensor)
                 _scaler.step(_optimizer)
                 _scaler.update()
+            else:
+                _optimizer.step(_closure())

         check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

         # Check that the two grad norm are equivalent
         # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
         # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
         # be valid for ShardedDDP
-        if clip_grad_norm:
+        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
+        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
+        if clip_grad_norm and not change_train_graph:
             total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
             if not torch.isnan(total_norm):
                 oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
-                assert torch.allclose(
-                    oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
-                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
+                allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)
+
+                if not allclose:
+                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
+                    for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
+                        if p_ddp.grad is not None:
+                            if p_sdp.grad is not None:
+                                print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
+                            else:
+                                print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)
+
+                assert (
+                    allclose
+                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
             else:
                 print(rank, "NaN grad norm in DDP", flush=True)

         # Flip the trainability of the first parameter back and forth
         if i == 0 and change_train_graph:
-            next(sharded_ddp_model.parameters()).requires_grad = not next(
-                sharded_ddp_model.parameters()
-            ).requires_grad
+            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
             next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
             check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

-    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
-    manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
-    for manual_reduction in manual_reductions:
-        print(
-            f"{rank}: Checking configuration: accumulate {grad_accumulation}"
-            + f" - change train graph {change_train_graph}"
-            + f" - amp {amp}"
-            + f" - manual reduction {manual_reduction}"
-            + f" - buffers {reduce_buffer_size}",
-            flush=True,
-        )
-        check_parity(manual_reduction=manual_reduction)
-    torch.cuda.synchronize()
-    torch.distributed.barrier()
     dist.destroy_process_group()
@@ -202,7 +205,13 @@ def run_ddp_parity(
 @pytest.mark.parametrize("fp16_reduction", _test_fp16_reduction)
 @pytest.mark.parametrize("clip_grad_norm", [True, False])
 @pytest.mark.parametrize("amp", _test_amp)
-def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp):
+@pytest.mark.parametrize("manual_reduction", [True, False])
+def test_ddp_parity(
+    reduce_buffer_size, grad_accumulation, change_train_graph, fp16_reduction, clip_grad_norm, amp, manual_reduction
+):
+    if manual_reduction and change_train_graph:
+        pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
     mp.spawn(
@@ -217,6 +226,7 @@ def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph, f
             fp16_reduction,
             clip_grad_norm,
             amp,
+            manual_reduction,
         ),
         nprocs=world_size,
         join=True,
...