Unverified commit 6552cbf8, authored by Hongxin Liu, committed by GitHub
Browse files

[booster] fix no_sync method (#3709)

* [booster] fix no_sync method

* [booster] add test for ddp no_sync

* [booster] fix merge

* [booster] update unit test

* [booster] update unit test

* [booster] update unit test
parent 3bf09efe
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import os import os
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase): ...@@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO() return GeminiCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
    """Gradient-sync suppression hook for GeminiPlugin; not implemented.

    Args:
        model: The model for which synchronization would be disabled (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
import warnings import warnings
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO() return LowLevelZeroCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
    """Gradient-sync suppression hook for LowLevelZeroPlugin; not implemented.

    Args:
        model: The model for which synchronization would be disabled (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union from typing import Callable, Iterator, List, Tuple, Union
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
...@@ -60,6 +60,13 @@ class Plugin(ABC): ...@@ -60,6 +60,13 @@ class Plugin(ABC):
""" """
pass pass
@abstractmethod
def no_sync(self, model: nn.Module) -> Iterator[None]:
    """
    Context manager to disable gradient synchronization.

    This is the abstract declaration only: it carries no behavior of its
    own, and concrete plugins are expected to provide the implementation.

    Args:
        model: The model whose gradient synchronization should be disabled.
    """
@abstractmethod @abstractmethod
def prepare_dataloader(self, def prepare_dataloader(self,
dataset: Dataset, dataset: Dataset,
......
from typing import Callable, List, Tuple, Union from typing import Callable, Iterator, List, Tuple, Union
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
...@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase): ...@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO() return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
    """Return the ``no_sync()`` context of the module wrapped by ``model``.

    Args:
        model: A ``TorchDDPModel``; anything else fails the assertion below.

    Returns:
        Whatever ``model.module.no_sync()`` returns.
        NOTE(review): presumably ``model.module`` is the DDP instance, so this
        is DDP's own ``no_sync`` context manager -- confirm against TorchDDPModel.

    Raises:
        AssertionError: If ``model`` is not a ``TorchDDPModel``.
    """
    assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
    wrapped_module = model.module
    return wrapped_module.no_sync()
from typing import Callable, List, Tuple, Union from typing import Callable, Iterator, List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase): ...@@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
def supported_precisions(self) -> List[str]: def supported_precisions(self) -> List[str]:
pass pass
def no_sync(self, model: nn.Module) -> Iterator[None]:
    """Do-nothing stub; returns ``None`` regardless of ``model``.

    NOTE(review): presumably present only so this test wrapper provides every
    method the plugin interface declares -- confirm against the base class.
    """
    return None
def check_dataloader_sharding(): def check_dataloader_sharding():
plugin = DPPluginWrapper() plugin = DPPluginWrapper()
......
from contextlib import nullcontext
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD from torch.optim import SGD
...@@ -44,10 +47,67 @@ def check_torch_ddp_plugin(): ...@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
torch.cuda.empty_cache() torch.cuda.empty_cache()
class DummyModel(nn.Module):
    """Tiny module with one learnable single-element weight that scales its input."""

    def __init__(self):
        super().__init__()
        # One-element tensor from torch.rand, registered as the only parameter.
        initial_value = torch.rand(1)
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        """Return ``self.weight * x``."""
        scaled = self.weight * x
        return scaled
def check_torch_ddp_no_sync():
    """Check that ``booster.no_sync`` suppresses gradient sync under TorchDDPPlugin.

    Runs two batches of a one-parameter model: the first inside
    ``booster.no_sync(model)``, the second inside a ``nullcontext``. After each
    backward pass the parameter's gradient is all-gathered from every rank and
    the number of distinct values is asserted: one per rank after the first
    batch, exactly one after the second.

    Requires an initialised process group (``dist.all_gather`` is called) and a
    CUDA device (batches are moved with ``.cuda()``).
    """
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = DummyModel()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=1e-3)

    # create a custom dataset with 0 to 10
    dataset = torch.arange(0, 10)
    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)

    model, optimizer, criterion, train_dataloader, _ = booster.boost(
        model,
        optimizer,
        criterion,
        dataloader=train_dataloader,
    )

    def run_forward_backward(current_batch):
        # One forward pass plus a backward pass routed through the booster.
        prediction = model(current_batch.cuda())
        loss_value = criterion(prediction)
        booster.backward(loss_value, optimizer)

    def gather_distinct_grads():
        # Returns the set of gradient values seen across all ranks.
        for param in model.parameters():
            # grad shape is (1, )
            assert param.grad.shape == (1,)
            per_rank_grads = [torch.empty_like(param.grad) for _ in range(dist.get_world_size())]
            dist.all_gather(per_rank_grads, param.grad)
            # the model has a single parameter, so return after the first one
            return {rank_grad.item() for rank_grad in per_rank_grads}

    for step, batch in enumerate(train_dataloader):
        if step > 1:
            # only check the first two batches
            break

        # no_sync for the first batch, sync for the second batch
        if step == 0:
            sync_ctx = booster.no_sync(model)
        else:
            sync_ctx = nullcontext()
        with sync_ctx:
            run_forward_backward(batch)

        distinct_grads = gather_distinct_grads()

        # for the first batch, all ranks should have different grads
        # for the second batch, as grad is synchronized, all ranks should have the same grads
        if step == 0:
            expected_distinct = dist.get_world_size()
        else:
            expected_distinct = 1
        assert len(distinct_grads) == expected_distinct
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
# init dist env # init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_torch_ddp_plugin() check_torch_ddp_plugin()
check_torch_ddp_no_sync()
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment