Unverified commit b488dcfa authored by Benjamin Lefaudeux, committed by GitHub

[bug] Make OSS Gloo-compliant (#102)

* Broadcasting grad-enabled tensors is forbidden under Gloo, because the in-place broadcast is not differentiable. Work around this by temporarily clearing requires_grad around the broadcast and restoring it afterwards.
parent d80c38f9
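In short, the fix clears requires_grad on each parameter just before the in-place broadcast is issued and restores it once the asynchronous request completes. A minimal sketch of that pattern, standalone and hedged: the helper name and the `params` list are hypothetical, and a process group is assumed to be initialized already.

```python
import torch.distributed as dist


def broadcast_params_grad_safe(params, src_rank, group=None):
    """Broadcast params in-place without tripping Gloo's check on grad-enabled tensors.

    Gloo asserts when asked to broadcast a tensor with requires_grad=True, since the
    in-place collective is not differentiable. Save the flag, clear it, issue the async
    broadcast, then restore the flag after the request completes.
    """
    requests, saved_flags = [], []
    for p in params:
        saved_flags.append((p, p.requires_grad))
        p.requires_grad = False
        requests.append(dist.broadcast(tensor=p, src=src_rank, group=group, async_op=True))

    # Wait for every broadcast, then put the grad requirement back as it was.
    for req, (p, flag) in zip(requests, saved_flags):
        req.wait()
        p.requires_grad = flag
```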
@@ -101,6 +101,7 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py
+        python benchmarks/oss.py --gloo
 run_oss_ddp_benchmark: &run_oss_ddp_benchmark
   - run:
@@ -18,14 +18,12 @@ from torchvision.transforms import ToTensor
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim.oss import OSS
-BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
 OPTIM = torch.optim.RMSprop
-def dist_init(rank, world_size):
-    dist.init_process_group(
-        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
-    )
+def dist_init(rank, world_size, backend):
+    print(f"Using backend: {backend}")
+    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
 def get_problem(rank, data_size, batch_size):
@@ -47,11 +45,11 @@ def get_problem(rank, data_size, batch_size):
 def train_oss_ddp(
-    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
+    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200, backend: str = "gloo",
 ):
     # DDP
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, backend)
     # Setup
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
@@ -120,13 +118,14 @@ def train(
     num_epochs: int = 10,
     batch_size: int = 32,
     data_size: int = 200,
+    backend: str = "gloo",
     use_oss: bool = True,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
 ):
     # DDP
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, backend)
     # Setup
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
@@ -214,6 +213,7 @@ if __name__ == "__main__":
     parser.add_argument("--check_regression", action="store_true", default=False)
     parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
     parser.add_argument("--reference_memory", action="store", default=4475, type=float)
+    parser.add_argument("--gloo", action="store_true", default=False)
     # beta - test oss_ddp
     parser.add_argument("--oss_ddp", action="store_true", default=False)
@@ -221,11 +221,12 @@ if __name__ == "__main__":
     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")
+    backend = "gloo" if args.gloo or not torch.cuda.is_available() else "nccl"
     if args.oss_ddp:
         print("\nBenchmark OSS DDP")
         mp.spawn(
             train_oss_ddp,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size),
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend),
             nprocs=args.world_size,
             join=True,
         )
@@ -233,7 +234,7 @@ if __name__ == "__main__":
         print("\nBenchmark vanilla optimizer")
         mp.spawn(
             train,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend, False, False),
             nprocs=args.world_size,
             join=True,
         )
@@ -246,6 +247,7 @@ if __name__ == "__main__":
             args.epochs,
             args.batch_size,
             args.data_size,
+            backend,
             True,
             args.check_regression,
             args.reference_speed,
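Putting the benchmark changes together: the backend is resolved once from the --gloo flag and CUDA availability, then threaded through mp.spawn into dist_init. A condensed, self-contained sketch of that plumbing; the port, default world size, and worker body are placeholders, not the benchmark's full code.

```python
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def dist_init(rank: int, world_size: int, backend: str) -> None:
    # Same rendezvous style as the benchmark: a fixed local TCP address.
    dist.init_process_group(
        backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size
    )


def worker(rank: int, world_size: int, backend: str) -> None:
    dist_init(rank, world_size, backend)
    # ... build the model/dataloader and run the training loop here ...
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gloo", action="store_true", default=False)
    parser.add_argument("--world_size", type=int, default=2)
    args = parser.parse_args()

    # Fall back to Gloo whenever it is requested explicitly or CUDA (hence NCCL) is unavailable.
    backend = "gloo" if args.gloo or not torch.cuda.is_available() else "nccl"

    mp.spawn(worker, args=(args.world_size, backend), nprocs=args.world_size, join=True)
```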
@@ -160,12 +160,22 @@ class OSS(Optimizer):
         # Sync all the states. Broadcast requests are issued async; we check completeness before moving on
         requests = []
+        requires_grad = []
         for rank, param_groups in enumerate(self.partition_parameters()):
             for param_group in param_groups:
                 for param in param_group["params"]:
+                    # NOTE: Broadcast is in-place and not differentiable.
+                    # Gloo will rightly assert on this operation for any tensor that requires grad.
+                    # We save and restore the grad requirement state to work around that; in our case
+                    # the grad is only useful on the source rank.
+                    requires_grad.append((param, param.requires_grad))
+                    param.requires_grad = False
                     requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))
-        _ = list(map(lambda x: x.wait(), requests))
+        for fut, req_grad in zip(requests, requires_grad):
+            fut.wait()
+            req_grad[0].requires_grad = req_grad[1]
         return loss
     def local_state_dict(self) -> dict:
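With the workaround in place, OSS.step() runs unchanged on a Gloo process group, so the sharded optimizer can be exercised on CPU-only machines. A minimal end-to-end sketch under stated assumptions: two CPU workers, a toy model and random data, the optim= keyword mirroring how the benchmark passes RMSprop to OSS; in a real run the model would be wrapped in ShardedDataParallel or DDP so gradients are synchronized before step().

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.optim.oss import OSS


def run(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:29502", rank=rank, world_size=world_size
    )

    model = torch.nn.Linear(32, 4)
    # OSS wraps a regular optimizer; the benchmark uses RMSprop, mirrored here.
    optimizer = OSS(model.parameters(), optim=torch.optim.RMSprop, lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    for _ in range(5):
        model.zero_grad()
        loss = loss_fn(model(torch.randn(16, 32)), torch.randn(16, 4))
        loss.backward()
        # Each rank updates its own parameter shard, then broadcasts the new values;
        # the broadcast no longer trips Gloo on grad-enabled parameters.
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)
```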