Unverified commit f74afebb authored by Min Xu, committed by GitHub

[fix] more adascale gradient accumulation tests and smoothing factor fix (#235)

* better ddp adascale tests

* make sure the single node test uses the same test cases and expected gains

* added unit test that covers smoothing factor

- tested by re-introducing the bug and seeing the test fail as expected.
parent 2eef71b9
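
For context, here is a small illustrative sketch (not part of this commit) of what the new default smoothing factor works out to; the world sizes and accumulation counts below are arbitrary example values:

# Illustrative only: mirrors the default-smoothing formula added in this commit.
def default_smoothing(world_size, num_gradients_to_accumulate=1):
    # More samples averaged per update -> smaller smoothing factor (shorter averaging window).
    return max(1 - (world_size * num_gradients_to_accumulate) / 1000, 0)

print(default_smoothing(2))      # 0.998 (2 workers, no accumulation)
print(default_smoothing(2, 2))   # 0.996 (2 workers, 2 accumulation passes)
print(default_smoothing(1024))   # 0     (very large scale: clamped at zero)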
@@ -72,7 +72,8 @@ class AdaScale(object):
             larger batch size (summed across all world_size) means a scale of
             10. If None, defaults to ``world_size``.
         smoothing (float):
-            Smoothing factor between batches. Default value: 0.9999
+            Smoothing factor for moving average. If None, it defaults to
+            max(1 - (world_size * num_gradients_to_accumulate)/1000, 0).
         num_gradients_to_accumulate (int):
             Number of passes that we accumulate gradients locally.
             Default to 1, which does not accumulate gradients.
@@ -83,7 +84,7 @@ class AdaScale(object):
         optimizer: torch.optim.Optimizer,
         world_size: Optional[int] = None,
         scale: Optional[float] = None,
-        smoothing: float = 0.999,
+        smoothing: float = None,
         num_gradients_to_accumulate: int = 1,
     ):
         self._optimizer = optimizer
@@ -91,7 +92,6 @@ class AdaScale(object):
         self._world_size: int = (
             world_size if world_size is not None else dist.get_world_size() if dist.is_initialized() else 1
         )
-        self._smoothing = smoothing
         self._num_backward_calls = 0
         self._num_grads_to_accum = num_gradients_to_accumulate
@@ -110,6 +110,12 @@ class AdaScale(object):
         self.set_scale(self._world_size * self._num_grads_to_accum if scale is None else scale)

+        # Set smoothing based on effective world_size rather than scale here, since world_size
+        # determines the number of samples being averaged over at every update
+        self._smoothing = (
+            max(1 - (self._world_size * self._num_grads_to_accum) / 1000, 0) if smoothing is None else smoothing
+        )
+
         # Register the gradient hooks. Note, don't assume every param will generate
         # a gradient (i.e. triggering the hook) in every backward pass.
         for idx, param_group in enumerate(self._optimizer.param_groups):
@@ -295,7 +301,7 @@ class AdaScale(object):
         grad_sqr = total_grad_sqr - grad_var / S
         grad_var = np.maximum(grad_var, 1e-6)
         grad_sqr = np.maximum(grad_sqr, 0.0)
-        theta = self._smoothing ** S
+        theta = self._smoothing
         self._update_avg("grad_sqr_avg", grad_sqr, theta)
         self._update_avg("grad_var_avg", grad_var, theta)
...
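
A side note on the `theta = self._smoothing` change above: the old code decayed the averages with `0.999 ** S`, and for modest scales the new default `1 - S / 1000` is numerically very close to that, while the `max(..., 0)` clamp keeps the factor valid at very large scales. A quick illustrative check (not part of the diff):

for S in [2, 4, 8, 16]:
    print(S, round(0.999 ** S, 6), max(1 - S / 1000, 0))
# 2 0.998001 0.998
# 4 0.996006 0.996
# 8 0.992028 0.992
# 16 0.984119 0.984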
@@ -32,33 +32,59 @@ def _dist_init(rank, world_size, tempfile_name, backend):
     torch.cuda.set_device(rank)


-def _test_basic_func(rank, world_size, tempfile_name):
+def _test_basic_func(rank, world_size, tempfile_name, test_case):
     _dist_init(rank, world_size, tempfile_name, backend="nccl")  # Covers nccl

     model = Linear(2, 2, bias=False)
     model.to("cuda")
     model = DDP(model, device_ids=[rank])
     optim = AdaScale(SGD(model.parameters(), lr=0.1))
-    # iter 1
-    in_data = Tensor([0.0, 0.0])
-    in_data[rank] = 1.0
-    in_data = in_data.cuda()
-    out = model(in_data)
-    out.sum().backward()
-    assert np.allclose(optim.gain(), 2.0), optim.gain()
-    optim.step()
-    optim.zero_grad()
+    if "input" in test_case:
+        # single iter
+        in_data = Tensor(test_case["input"][rank])
+        in_data = in_data.cuda()
+        out = model(in_data)
+        out.sum().backward()
+        assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
+        optim.step()
+        optim.zero_grad()
+    else:
+        # multiple iters
+        for in_data in test_case["inputs"]:
+            in_data = Tensor(in_data[rank]).cuda()
+            out = model(in_data)
+            out.sum().backward()
+            optim.step()
+            optim.zero_grad()
+        assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
     dist.destroy_process_group()


+# IMPORTANT: make sure these test_cases values are sync'ed with the non-DDP
+# test in test_single_node_adascale.py. This way, we make sure gradient accumulation
+# works exactly like that in DDP.
 @skip_if_single_gpu
-def test_basic():
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        # "input" value is a list of input tensors for rank 0 and rank 1.
+        {"input": [[1.0, 0], [0, 1.0]], "expected_gain": 2.0},
+        {"input": [[1.0, 1.0], [1.0, 1.0]], "expected_gain": 1.0000001249999846},
+        {"input": [[-1.0, 1.0], [1.0, -1.0]], "expected_gain": 2.0},
+        {"input": [[1.0, 4.0], [5.0, 0.5]], "expected_gain": 1.5022222222222221},
+        {"input": [[-0.2, 3.0], [5.0, 0.5]], "expected_gain": 1.9433267229211089},
+        # "inputs" to trigger multiple iteration tests, which make sure the
+        # smoothing factor calculation is also covered.
+        {"inputs": [[[-0.2, 3.3], [5.2, 0.7]], [[1.0, 4.0], [3.1, 0.1]]], "expected_gain": 1.744159431359284},
+    ],
+)
+def test_basic(test_case):
     """Test adascale with DDP without gradient accumulation"""
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]

-    mp.spawn(_test_basic_func, args=(world_size, temp_file_name), nprocs=world_size, join=True)
+    mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case), nprocs=world_size, join=True)


 def _test_grad_accum_func(rank, world_size, tempfile_name):
...
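
As a sanity check on the expected gains above, here is an illustrative back-of-the-envelope computation (not part of the commit) for the first parametrized case. It assumes the standard AdaScale gain formula gain = (var + sqr) / (var / scale + sqr) together with the variance/square estimates visible in the adascale.py hunk above; on the very first backward pass the moving averages equal the raw estimates, so smoothing does not enter yet.

import numpy as np

S = 2.0  # scale = world_size, no gradient accumulation
# Linear(2, 2, bias=False) with out.sum().backward(): each weight row's gradient
# equals that rank's input vector.
g0 = np.array([[1.0, 0.0], [1.0, 0.0]])  # rank 0 local grad, input [1.0, 0]
g1 = np.array([[0.0, 1.0], [0.0, 1.0]])  # rank 1 local grad, input [0, 1.0]
g = (g0 + g1) / 2                        # DDP-averaged gradient

local_grad_sqr = (np.sum(g0 ** 2) + np.sum(g1 ** 2)) / 2    # 2.0
total_grad_sqr = np.sum(g ** 2)                             # 1.0
grad_var = (local_grad_sqr - total_grad_sqr) * S / (S - 1)  # 2.0
grad_sqr = total_grad_sqr - grad_var / S                    # 0.0
gain = (grad_var + grad_sqr) / (grad_var / S + grad_sqr)
print(gain)  # 2.0, matching expected_gain for {"input": [[1.0, 0], [0, 1.0]]}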
@@ -57,39 +57,61 @@ def test_loss_accum_cpu():
     assert np.allclose(optim.gain(), 1.0), optim.gain()


-def test_grad_accum_cpu(cpu=True):
-    """Test the basic functionality on CPU with gradient accumulation without DDP"""
+# IMPORTANT: make sure these test_cases values are sync'ed with the DDP
+# test in test_ddp_adascale.py. This way, we make sure gradient accumulation
+# works exactly like that in DDP.
+@pytest.mark.parametrize("cpu", [True, False])
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        # "input" value is a list of input tensors for micro-batch 0 and micro-batch 1.
+        {"input": [[1.0, 0], [0, 1.0]], "expected_gain": 2.0},
+        {"input": [[1.0, 1.0], [1.0, 1.0]], "expected_gain": 1.0000001249999846},
+        {"input": [[-1.0, 1.0], [1.0, -1.0]], "expected_gain": 2.0},
+        {"input": [[1.0, 4.0], [5.0, 0.5]], "expected_gain": 1.5022222222222221},
+        {"input": [[-0.2, 3.0], [5.0, 0.5]], "expected_gain": 1.9433267229211089},
+        # "inputs" to trigger multiple iteration tests, which make sure the
+        # smoothing factor calculation is also covered.
+        {"inputs": [[[-0.2, 3.3], [5.2, 0.7]], [[1.0, 4.0], [3.1, 0.1]]], "expected_gain": 1.744159431359284},
+    ],
+)
+def test_grad_accum(test_case, cpu):
+    """Test the basic functionality on CPU/GPU with gradient accumulation without DDP"""
     model = Linear(2, 2, bias=False)
     if not cpu:
+        if torch.cuda.device_count() < 1:
+            pytest.skip("1 GPU is required")
         model = model.cuda()
     optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
-    for expected_gain in [2.0, 2.0]:  # test 2 iterations catch more corner cases.
+    expected_gain = test_case["expected_gain"]
+    if "input" in test_case:
+        data = [test_case["input"]] * 2
+        gains = [expected_gain] * 2
+    else:
+        data = test_case["inputs"]
+        gains = [None, expected_gain]
+    for in_data, exp_gain in zip(data, gains):  # test 2 iterations catch more corner cases.
         # grad pass 1
-        in_data = Tensor([0.0, 1.0])
+        in_data_0 = Tensor(in_data[0])
         if not cpu:
-            in_data = in_data.cuda()
-        out = model(in_data)
+            in_data_0 = in_data_0.cuda()
+        out = model(in_data_0)
         out.sum().backward()
         # grad pass 2
-        in_data = Tensor([1.0, 0.0])
+        in_data_1 = Tensor(in_data[1])
         if not cpu:
-            in_data = in_data.cuda()
-        out = model(in_data)
+            in_data_1 = in_data_1.cuda()
+        out = model(in_data_1)
         out.sum().backward()
+        if exp_gain is not None:
+            assert np.allclose(optim.gain(), exp_gain), optim.gain()
         # stepping it. Note that if we did more than 2 passes as promised by the
         # num_gradients_to_accumulate argument above, AdaScale is not be able to
         # detect that mistake for now. The result will just be wrong in that case.
-        assert np.allclose(optim.gain(), expected_gain), optim.gain()
         optim.step()
         optim.zero_grad()


-@skip_if_no_gpu
-def test_grad_accum_gpu():
-    """Test the basic functionality on GPU with gradient accumulation without DDP"""
-    test_grad_accum_cpu(cpu=False)


 @skip_if_no_gpu
 def test_state_checkpointing():
     """ Test state checkpointing on GPU since that's the common case.
...
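
To run just these tests locally, something along these lines should work (the exact paths to test_single_node_adascale.py and test_ddp_adascale.py inside the repo are not shown in this diff, so adjust as needed; the DDP test needs at least 2 GPUs):

python -m pytest -k test_grad_accum test_single_node_adascale.py
python -m pytest -k test_basic test_ddp_adascale.py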