Unverified Commit efed9cee authored by Min Xu, committed by GitHub

[test] AdaScale & SDP/FSDP (#468)

- cover them in terms of code paths only
- numerically, AdaScale behaves differently on SDP/FSDP than on DDP, mainly
  because each rank only sees a partial view of the gradients
- this doesn't mean it is definitely not useful, but it has yet to be validated
- not going to spend too much time on this until we have a real use case
parent eeabc6f1
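For orientation, here is a minimal sketch of the DDP code path the new tests exercise, pieced together from the test changes below; the process-group setup is elided and the hyperparameters are just the test's values, so treat it as illustrative rather than a recommended recipe. Under SDP/FSDP the same calls run, but the gain is not checked against the DDP golden values.

```python
# Illustrative sketch assembled from the test in this commit (not a canonical recipe).
# Assumes torch.distributed is already initialized with a NCCL process group.
import torch
from torch.nn import Linear
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

from fairscale.optim import AdaScale


def one_step(rank: int, in_data: torch.Tensor) -> float:
    model = Linear(2, 2).to("cuda")
    model = DDP(model, device_ids=[rank])
    # AdaScale wraps the base optimizer and derives a scaling "gain" from gradient stats.
    optim = AdaScale(SGD(model.parameters(), lr=0.1))

    model(in_data.cuda()).sum().backward()
    gain = optim.gain()  # numerically validated against golden data only for DDP
    optim.step()
    optim.zero_grad()
    return gain
```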
@@ -160,7 +160,7 @@ class AsyncAMPnetEventLoop:
 reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
 batch = Batch(reqd_input, count)
 if self.weight_prediction:
-    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
+    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
 activations[count], message = self.async_send_inner(batch, count)
 self.transport.send_message(message, sync=True)
 count += 1
@@ -177,7 +177,7 @@ class AsyncAMPnetEventLoop:
 reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
 batch = Batch(reqd_input, count)
 if self.weight_prediction:
-    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
+    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
 activations[count], forward_message = self.async_send_inner(batch, count)
 count += 1
@@ -186,7 +186,7 @@ class AsyncAMPnetEventLoop:
 args: AsyncMessageBody = message.args
 assert args.message_type is AsyncMessageType.Gradients
 if self.weight_prediction:
-    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
+    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
 self.async_grad_inner(message, activations)
 # Send after grad
@@ -208,7 +208,7 @@ class AsyncAMPnetEventLoop:
 args = message.args
 assert args.message_type is AsyncMessageType.Gradients
 if self.weight_prediction:
-    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
+    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
 self.async_grad_inner(message, activations)
 num_gradients += 1
@@ -248,7 +248,7 @@ class AsyncAMPnetEventLoop:
 batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)
 if self.weight_prediction:
-    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
+    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
 task = create_task_without_skip_trackers(
     self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
 )
@@ -257,7 +257,7 @@ class AsyncAMPnetEventLoop:
 task.finalize(output)
 # one backward
 if self.weight_prediction:
-    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
+    optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
 output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
 loss = criterion(output_tensor, reqd_target)
@@ -307,7 +307,7 @@ class AsyncAMPnetEventLoop:
 n_warmup = ranks[-1] - cur_rank
 for _ in range(n_warmup):
     if self.weight_prediction:
-        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
+        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
     message = self.event_loop_trunk_forward_helper(activations)
     self.transport.send_message(message, sync=True)
     num_activations += 1
@@ -316,13 +316,13 @@ class AsyncAMPnetEventLoop:
 while num_activations < num_microbatch:
     # 1 Forward
     if self.weight_prediction:
-        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)  # type: ignore
+        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
     message = self.event_loop_trunk_forward_helper(activations)
     num_activations += 1
     # 1 Backward
     if self.weight_prediction:
-        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
+        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
     self.event_loop_trunk_backward_helper(activations)
     num_gradients += 1
     if self.perform_optimizer_step(optimizer, num_gradients):
@@ -336,7 +336,7 @@ class AsyncAMPnetEventLoop:
 remaining = len(activations)
 for _ in range(remaining):
     if self.weight_prediction:
-        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)  # type: ignore
+        optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
     self.event_loop_trunk_backward_helper(activations)
     num_gradients += 1
     if self.perform_optimizer_step(optimizer, num_gradients):
@@ -103,7 +103,7 @@ class ShardedDataParallel(nn.Module):
         super().__init__()
         self.module = module
-        self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
+        self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
         self.enable_broadcast_buffers = broadcast_buffers
         self.auto_refresh_trainable = auto_refresh_trainable
         self.reduce_fp16 = reduce_fp16
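The ShardedDataParallel change above relaxes the wrapping check from `isinstance(sharded_optimizer, OSS)` to "anything that is not already a list", so an OSS optimizer wrapped in AdaScale can be passed directly. A minimal sketch of that combination, taken from the SDP branch of the test further below (process-group setup elided, values illustrative):

```python
# Sketch of the SDP + AdaScale(OSS) combination enabled by the isinstance change above.
# Assumes an initialized process group; mirrors the SDP branch of the test in this commit.
from torch.nn import Linear
from torch.optim import SGD

from fairscale.nn.data_parallel import ShardedDataParallel as SDP
from fairscale.optim import OSS, AdaScale

model = Linear(2, 2).to("cuda")
optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))  # not an OSS instance itself
model = SDP(model, sharded_optimizer=optim)             # accepted: it simply isn't a list
```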
@@ -596,6 +596,13 @@ class AdaScale(Optimizer):
         # not needed, so the smoothing factor is 0.
         self._smoothing = max(1 - self._world_size * self._num_grads_to_accum / 1000, 0)
 
+    def __getattr__(self, name: str) -> Any:
+        """Forward missing attributes to wrapped optimizer."""
+        try:
+            return super().__getattr__(name)  # defer to Optimizer logic
+        except AttributeError:
+            return getattr(self._optimizer, name)  # fallback to wrapped optim
+
 
 class AdaScaleWrapper(AdaScale):
     """
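The `__getattr__` added to AdaScale above lets attribute lookups that miss on the wrapper fall through to the wrapped optimizer (the stub entry in the next hunk teaches mypy the same thing, which is presumably why the `# type: ignore` comments in the AMPnet hunks earlier could be dropped). A minimal standalone sketch of the forwarding pattern, with made-up class names rather than fairscale code:

```python
from typing import Any


class OptimWrapper:
    """Toy wrapper that forwards unknown attribute lookups to the wrapped object."""

    def __init__(self, optimizer: Any) -> None:
        self._optimizer = optimizer

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal lookup fails on the wrapper itself, so
        # attributes defined here (like _optimizer) are served normally.
        return getattr(self._optimizer, name)


# e.g. wrapper.param_groups resolves to the wrapped optimizer's param_groups.
```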
@@ -9,6 +9,7 @@ class Optimizer(object):
     param_groups: List[Dict]
     state: Dict
     def __init__(self, params: _params_t, defaults: Optional[Dict]=None, lr: Optional[float]=None) -> None: ...
+    def __getattr__(self, name: str) -> Any: ...
     def state_dict(self) -> Dict: ...
     def load_state_dict(self, state_dict: Dict) -> None: ...
     def zero_grad(self) -> None: ...
@@ -7,7 +7,19 @@
 # pylint: disable=missing-class-docstring
 # pylint: disable=missing-function-docstring
 
-""" Test AdaScale with DDP. """
+""" Test AdaScale with DDP/SDP/FSDP.
+
+Even though it is covered by these tests, AdaScale does NOT work with SDP/FSDP
+the same way it does with DDP and the gradient accumulation modes, because the
+full gradients are not sent to each worker.
+
+In FSDP's case, each worker only holds a slice of the reduced gradient; in
+SDP's case, only a subset of the gradients is reduced on each worker. On the
+other hand, each AdaScale worker receives the full local gradient, so the gain
+value computation is off. If each worker instead used a slice (or subset) of
+its local gradient, the gain values they compute would differ from one
+another, which might or might not be helpful for training.
+"""
 
 import tempfile
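As a quick numeric illustration of the docstring's point (plain NumPy, not AdaScale's actual gain formula): a statistic built from a gradient shard differs from the same statistic on the full gradient, so a rank that only sees its slice cannot reproduce the numbers it would get under DDP.

```python
# Toy example only: shows that per-shard gradient statistics differ from
# full-gradient statistics. AdaScale's real gain formula is not reproduced here.
import numpy as np

full_grad = np.array([1.0, 2.0, 3.0, 4.0])
shard_rank0, shard_rank1 = full_grad[:2], full_grad[2:]

print(np.dot(full_grad, full_grad))      # 30.0 -> squared norm seen with full gradients (DDP)
print(np.dot(shard_rank0, shard_rank0))  #  5.0 -> rank 0's partial view (e.g. an FSDP slice)
print(np.dot(shard_rank1, shard_rank1))  # 25.0 -> rank 1's partial view
```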
@@ -21,7 +33,9 @@ from torch.nn import Linear
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 
-from fairscale.optim import AdaScale
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+from fairscale.nn.data_parallel import ShardedDataParallel as SDP
+from fairscale.optim import OSS, AdaScale
 from fairscale.utils.golden_testing_data import adascale_test_data
 from fairscale.utils.testing import skip_if_single_gpu
@@ -32,20 +46,35 @@ def _dist_init(rank, world_size, tempfile_name, backend):
     torch.cuda.set_device(rank)
 
 
-def _test_basic_func(rank, world_size, tempfile_name, test_case):
+def _test_basic_func(rank, ddp_cls, world_size, tempfile_name, test_case):
     _dist_init(rank, world_size, tempfile_name, backend="nccl")  # Covers nccl
 
     model = Linear(2, 2)
     model.to("cuda")
-    model = DDP(model, device_ids=[rank])
-    optim = AdaScale(SGD(model.parameters(), lr=0.1))
+    if ddp_cls is DDP:
+        model = ddp_cls(model, device_ids=[rank])
+        optim = AdaScale(SGD(model.parameters(), lr=0.1))
+    elif ddp_cls is SDP:
+        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
+        model = ddp_cls(model, sharded_optimizer=optim)
+    else:
+        assert ddp_cls is FSDP, ddp_cls
+        # Two cases:
+        #   flatten=True : the AdaScale wrapper must come after FSDP, and it
+        #                  receives a single flattened grad tensor. It won't
+        #                  receive grads if it wraps the optimizer before FSDP.
+        #   flatten=False: AdaScale can come either before or after FSDP.
+        # So it is better to apply AdaScale after FSDP.
+        model = ddp_cls(model, flatten_parameters=False)
+        optim = AdaScale(SGD(model.parameters(), lr=0.1))
 
     if "input" in test_case:
         # single iter
         in_data = Tensor(test_case["input"][rank])
         in_data = in_data.cuda()
         out = model(in_data)
         out.sum().backward()
-        assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
+        if ddp_cls is DDP:
+            assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
         optim.step()
         optim.zero_grad()
     else:
@@ -56,19 +85,29 @@ def _test_basic_func(rank, world_size, tempfile_name, test_case):
             out.sum().backward()
             optim.step()
             optim.zero_grad()
-        assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
+        if ddp_cls is DDP:
+            assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
 
     dist.destroy_process_group()
 
 
 @skip_if_single_gpu
+@pytest.mark.parametrize("ddp_cls", [DDP])
 @pytest.mark.parametrize("test_case", adascale_test_data)
-def test_basic(test_case):
+def test_basic(ddp_cls, test_case):
     """Test adascale with DDP without gradient accumulation"""
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
 
-    mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case), nprocs=world_size, join=True)
+    mp.spawn(_test_basic_func, args=(ddp_cls, world_size, temp_file_name, test_case), nprocs=world_size, join=True)
+
+
+@skip_if_single_gpu
+@pytest.mark.parametrize("ddp_cls", [DDP, SDP, FSDP])
+@pytest.mark.parametrize("test_case", adascale_test_data[:1])
+def test_basic_all_dp(ddp_cls, test_case):
+    """Test adascale with DDP/SDP/FSDP with just one test case."""
+    test_basic(ddp_cls, test_case)
 
 
 def _test_grad_accum_func(rank, world_size, tempfile_name):
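Finally, a hedged sketch of the FSDP ordering described in the test comment: with `flatten_parameters=False` AdaScale can wrap the optimizer either before or after FSDP, and wrapping after FSDP also covers the flattened case, so the test (and this sketch) applies AdaScale after FSDP. Process-group setup is elided, and the gain is not validated numerically in this configuration.

```python
# Sketch mirroring the FSDP branch of the test above; assumes an initialized
# process group. AdaScale gains under FSDP are not checked against DDP goldens.
from torch.nn import Linear
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim import AdaScale

model = FSDP(Linear(2, 2).to("cuda"), flatten_parameters=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1))  # applied after FSDP wrapping
```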