Unverified Commit b488dcfa authored by Benjamin Lefaudeux, committed by GitHub

[bug] Make OSS Gloo-compliant (#102)

* Broadcasting grad-enabled tensors is forbidden with the Gloo backend, because the in-place broadcast is not differentiable. Work around this by temporarily clearing requires_grad before the broadcast and restoring it afterwards.
parent d80c38f9
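
The core of the change follows this pattern (a minimal sketch, assuming an already-initialized process group; the helper name broadcast_without_grad and its arguments are illustrative, not the actual OSS code):

import torch.distributed as dist


def broadcast_without_grad(params, src_rank, group=None):
    # Broadcast each tensor in `params` from `src_rank` while hiding
    # requires_grad from the backend. Assumes dist.init_process_group()
    # has already been called; group=None means the default group.
    requests, saved_flags = [], []
    for p in params:
        # Gloo rejects broadcasting a tensor that requires grad, since the
        # in-place broadcast is not differentiable: clear the flag first.
        saved_flags.append((p, p.requires_grad))
        p.requires_grad = False
        requests.append(dist.broadcast(tensor=p, src=src_rank, group=group, async_op=True))

    # Wait for the async broadcasts to complete, then restore the original flags.
    for work, (p, flag) in zip(requests, saved_flags):
        work.wait()
        p.requires_grad = flag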
@@ -101,6 +101,7 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py
+        python benchmarks/oss.py --gloo

 run_oss_ddp_benchmark: &run_oss_ddp_benchmark
   - run:
...
@@ -18,14 +18,12 @@ from torchvision.transforms import ToTensor
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim.oss import OSS

-BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
 OPTIM = torch.optim.RMSprop


-def dist_init(rank, world_size):
-    dist.init_process_group(
-        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
-    )
+def dist_init(rank, world_size, backend):
+    print(f"Using backend: {backend}")
+    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


 def get_problem(rank, data_size, batch_size):
@@ -47,11 +45,11 @@ def get_problem(rank, data_size, batch_size):

 def train_oss_ddp(
-    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
+    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200, backend: str = "gloo",
 ):
     # DDP
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, backend)

     # Setup
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
@@ -120,13 +118,14 @@ def train(
     num_epochs: int = 10,
     batch_size: int = 32,
     data_size: int = 200,
+    backend: str = "gloo",
     use_oss: bool = True,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
 ):
     # DDP
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, backend)

     # Setup
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
@@ -214,6 +213,7 @@ if __name__ == "__main__":
     parser.add_argument("--check_regression", action="store_true", default=False)
     parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
     parser.add_argument("--reference_memory", action="store", default=4475, type=float)
+    parser.add_argument("--gloo", action="store_true", default=False)

     # beta - test oss_ddp
     parser.add_argument("--oss_ddp", action="store_true", default=False)
@@ -221,11 +221,12 @@ if __name__ == "__main__":
     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")

backend = "nccl" if not args.gloo or not torch.cuda.is_available() else "gloo"
     if args.oss_ddp:
         print("\nBenchmark OSS DDP")
         mp.spawn(
             train_oss_ddp,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size),
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend),
             nprocs=args.world_size,
             join=True,
         )
@@ -233,7 +234,7 @@ if __name__ == "__main__":
         print("\nBenchmark vanilla optimizer")
         mp.spawn(
             train,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend, False, False),
             nprocs=args.world_size,
             join=True,
         )
@@ -246,6 +247,7 @@ if __name__ == "__main__":
             args.epochs,
             args.batch_size,
             args.data_size,
+            backend,
             True,
             args.check_regression,
             args.reference_speed,
...
@@ -160,12 +160,22 @@ class OSS(Optimizer):
         # Sync all the states. Broadcast requests are issued async, we check completeness before moving on
         requests = []
+        requires_grad = []
         for rank, param_groups in enumerate(self.partition_parameters()):
             for param_group in param_groups:
                 for param in param_group["params"]:
+                    # NOTE: Broadcast is in-place and not differentiable
+                    # Gloo will rightly assert on this operation for any tensor that requires grad.
+                    # We save and restore the grad requirement state to work around that, in our case
+                    # the grad is only useful on the source rank.
+                    requires_grad.append((param, param.requires_grad))
+                    param.requires_grad = False
                     requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))

-        _ = list(map(lambda x: x.wait(), requests))
+        for fut, req_grad in zip(requests, requires_grad):
+            fut.wait()
+            req_grad[0].requires_grad = req_grad[1]
+
         return loss

     def local_state_dict(self) -> dict:
...