Unverified Commit 46c3776b authored by Benjamin Lefaudeux, committed by GitHub

[feat] Simple macro OSS benchmark (#47)



* initial commit, dummy training loop, pure pytorch but not DDP

* probably slightly broken, but rough DDP benchmark run

* adding the torchvision requirement for testing

* brainfart

* reduce the loss, do something slightly distributed

* Some cleanup, distributing the training on two GPUs

* some cleanup + adding a vanilla run, still not good to go

* less silly defaults, gtg for a start I think

* smaller batch to fit the smaller gpus used in the circleci rigs

* Adding some options for the benchmark, and regression testing

* [test] set torch seed for Adam tests (#49)

Set the torch seed for tests. xfail the mixed-precision and memory-efficient mixed-precision state_dict tests, since their states are cast to FP16 and back to FP32 during load_state_dict, which is lossy (a short illustration follows the commit metadata below).
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>

* linting, I really need to automate this isort insanity
Co-authored-by: Jun Ru Anderson <33384298+andersonic@users.noreply.github.com>
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 0e8c2a96
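As a side note on the #49 commit above: the xfail is expected because an FP32-to-FP16-to-FP32 round trip is lossy. A minimal standalone sketch of the effect (illustrative only, not part of this diff):

import torch

x = torch.rand(4, dtype=torch.float32)
roundtrip = x.half().float()

# FP16 keeps roughly 3 decimal digits of precision, so the round trip
# usually changes the values slightly
print(torch.equal(x, roundtrip))  # most likely False
print((x - roundtrip).abs().max().item())  # small but typically non-zero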
@@ -96,6 +96,12 @@ run_transformer_benchmark: &run_transformer_benchmark
      command: |
        python benchmarks/transformer.py

run_oss_benchmark: &run_oss_benchmark
  - run:
      name: Run OSS Benchmark
      command: |
        python benchmarks/oss.py
# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
@@ -244,6 +250,9 @@ jobs:
      - <<: *run_transformer_benchmark
      - <<: *run_oss_benchmark

workflows:
  version: 2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
import math
import os
import time
from typing import Any, List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore


def dist_init(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
):
    # DDP
    dist_init(rank, world_size)

    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(rank),
            "label": torch.stack([i[1] for i in inputs]).to(rank),
        }

    def _print(msg):
        if dist.get_rank() == 0:
            print(msg)

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()

    # Shard the optimizer state across the ranks when requested, plain SGD otherwise
    optimizer = (
        OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
        if use_oss
        else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    )
    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                # Average the loss across ranks so every process reports the same value
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                return loss

            # The closure-style step keeps OSS and vanilla SGD interchangeable here
            optimizer.step(closure)

        epoch_end = time.monotonic()
        measurements.append(data_size / (epoch_end - epoch_start))
        _print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
    if use_oss and check_regression and dist.get_rank() == 0:
        # Compute the mean and standard deviation of the per-epoch throughput
        mean = sum(measurements) / len(measurements)
        diff = map(lambda x: pow(x - mean, 2.0), measurements)
        std = math.sqrt(sum(diff) / (len(measurements) - 1))
        print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")

        # Fail if even the optimistic bound (mean + 3 std) falls below the reference throughput
        assert (mean + 3.0 * std) > reference_speed, "Regression detected"
        print("[Regression Test] VALID")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the optimizer state sharding, on a typical computer vision workload"
    )
    parser.add_argument("--world_size", action="store", default=2, type=int)
    parser.add_argument("--epochs", action="store", default=10, type=int)
    parser.add_argument("--batch_size", action="store", default=32, type=int)
    parser.add_argument("--data_size", action="store", default=512, type=int)
    # argparse's type=bool would treat any non-empty string (even "False") as True,
    # so parse the flag explicitly
    parser.add_argument(
        "--check_regression", action="store", default=True, type=lambda x: x.lower() in ("true", "1", "yes")
    )
    parser.add_argument("--reference_speed", action="store", default=39.82, type=float)

    args = parser.parse_args()

    print("\nBenchmark vanilla SGD")
    mp.spawn(
        train,
        args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
        nprocs=args.world_size,
        join=True,
    )
print("\nBenchmark OSS")
mp.spawn(
train,
args=(
args.world_size,
args.epochs,
args.batch_size,
args.data_size,
True,
args.check_regression,
args.reference_speed,
),
nprocs=args.world_size,
join=True,
)
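For reference, the benchmark is self-launching: mp.spawn starts one process per rank and passes the rank as the first argument of train(), with the args tuple filling in the remaining parameters. A typical invocation (assuming a machine with at least two GPUs) would be:

    python benchmarks/oss.py --world_size 2 --epochs 5 --batch_size 32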
@@ -49,7 +49,7 @@ class OSS(Optimizer):
     in_super_constructor: bool

     def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
-        # Hold all the nmodel params in the root .param_groups
+        # Hold all the model params in the root .param_groups
         self.in_super_constructor = True
         super().__init__(params, defaults)
         self.in_super_constructor = False
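For context, a minimal sketch of why the in_super_constructor flag above exists (hypothetical names, not the fairscale implementation): torch.optim.Optimizer.__init__ calls add_param_group() once per group, so a subclass that re-partitions its parameters on every add_param_group() call needs a flag to defer that work until the base constructor has finished.

from typing import Any, Dict

from torch.optim import Optimizer


class ShardedOptimizer(Optimizer):
    """Illustrative only: defer partitioning until construction is complete."""

    def __init__(self, params: Any, **defaults: Any):
        self.in_super_constructor = True
        super().__init__(params, defaults)  # triggers add_param_group() per group
        self.in_super_constructor = False
        self._partition()  # partition once, now that all groups are registered

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        super().add_param_group(param_group)
        if not self.in_super_constructor:
            self._partition()  # groups added later re-trigger the partitioning

    def _partition(self) -> None:
        # Placeholder: a sharded optimizer would split self.param_groups across ranks here
        pass

    def step(self, closure: Any = None) -> Any:
        # Placeholder step; a real implementation would run the wrapped optimizer
        return None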
@@ -6,5 +6,6 @@ pytest == 5.4.1
pytest-cov == 2.10.0
torchtext == 0.6.0
torch >= 1.5.1
torchvision >= 0.6.0
# NOTE(msb) not a dependency but needed by torch
numpy == 1.17.4