Unverified Commit 6f8a8652 authored by Benjamin Lefaudeux, committed by GitHub

[feature] OSS: Use MNIST to benchmark (#159)

* switching to MNIST
* updating the reference values, should be good to go
* download dataset once for all processes
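
A minimal standalone sketch of the data path these changes introduce (assuming torchvision is installed; `make_loader` and the constants below are illustrative only, not part of this diff): the parent process downloads MNIST once into the system temp directory, and each spawned rank then builds its own sharded loader from the cached copy.

```python
import tempfile

from torch.utils.data import BatchSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

TEMPDIR = tempfile.gettempdir()


def make_loader(rank: int, world_size: int, batch_size: int) -> DataLoader:
    # download=False: the parent process fetched the data before spawning,
    # so every rank just reads the cached copy under TEMPDIR
    dataset = MNIST(root=TEMPDIR, transform=ToTensor(), download=False)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    # drop_last keeps the number of steps identical on every rank
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    return DataLoader(dataset=dataset, batch_sampler=batch_sampler)


if __name__ == "__main__":
    MNIST(root=TEMPDIR, download=True)  # fetch once, before spawning workers
    loader = make_loader(rank=0, world_size=4, batch_size=256)
    print(f"{len(loader)} batches per rank")
```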
parent 577dcd98
@@ -108,7 +108,7 @@ run_oss_benchmark: &run_oss_benchmark
- run:
name: Run OSS Benchmark
command: |
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 13.7 --reference_memory 4390 --reference_loss 0.302
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 800 --reference_memory 1120 --reference_loss 0.049
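# reference values re-baselined for the MNIST run (speed in img/s, peak memory in MiB)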
run_oss_gloo: &run_oss_gloo
- run:
@@ -4,7 +4,10 @@
import argparse
from enum import Enum
import importlib
import logging
import math
import shutil
import tempfile
import time
from typing import Any, List, Optional, cast
@@ -15,38 +18,40 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS
OPTIM = torch.optim.RMSprop
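# MNIST is cached in the system temp dir: the parent process downloads it once, and every spawned rank reuses the same copy (download=False in get_problem)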
TEMPDIR = tempfile.gettempdir()
def dist_init(rank, world_size, backend):
print(f"Using backend: {backend}")
logging.info(f"Using backend: {backend}")
dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
def get_problem(rank, data_size, batch_size, device, model_name: str):
def get_problem(rank, world_size, batch_size, device, model_name: str):
# Select the desired model on the fly
print(f"Using {model_name} for benchmarking")
logging.info(f"Using {model_name} for benchmarking")
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
# Data setup, dummy data
# Data setup, duplicate the grey channels to get pseudo color
def collate(inputs: List[Any]):
return {
"inputs": torch.stack([i[0] for i in inputs]).to(device),
"label": torch.stack([i[1] for i in inputs]).to(device),
"inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
"label": torch.tensor([i[1] for i in inputs]).to(device),
}
dataloader = DataLoader(
dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
batch_size=batch_size,
collate_fn=collate,
)
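# each rank draws a disjoint shard of MNIST via DistributedSampler; drop_last keeps the number of batches identical across ranks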
dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)
loss_fn = nn.CrossEntropyLoss()
return model, dataloader, loss_fn
@@ -65,6 +70,8 @@ def train(
optim_type: OptimType = OptimType.vanilla,
check_regression: bool = True,
):
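# mp.spawn starts fresh worker processes, so logging is configured here for each rank as well as in __main__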
logging.basicConfig(level=logging.INFO)
# DDP
dist_init(rank=rank, world_size=args.world_size, backend=backend)
@@ -79,10 +86,11 @@ def train(
torch.backends.cudnn.benchmark = False
device = torch.device("cpu") if args.cpu else torch.device(rank)
model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device, args.torchvision_model)
model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)
# Shard the optimizer
optimizer: Optional[torch.optim.Optimizer] = None
model = cast(nn.Module, model)
if optim_type == OptimType.oss_sdp:
ddp = ShardedDataParallel(
@@ -108,7 +116,7 @@ def train(
torch.cuda.reset_peak_memory_stats(rank)
torch.cuda.synchronize(rank)
# Dummy training loop
# Standard training loop
training_start = time.monotonic()
model.train()
@@ -117,9 +125,11 @@ def train(
need_profiling = args.profile
for epoch in range(args.epochs):
epoch_start = time.monotonic()
n_items = 0
epoch_runtime = 0.0
for batch in dataloader:
batch_start = time.monotonic()
def closure():
model.zero_grad()
@@ -133,11 +143,11 @@ def train(
return loss
if need_profiling and not args.cpu:
print("Profiling the run")
logging.info("Profiling the run")
with profiler.profile(use_cuda=True) as prof: # type: ignore
with profiler.record_function("batch"):
final_loss = optimizer.step(closure)
print("profiling done, final loss ", cast(float, final_loss))
logging.info("profiling done, final loss ", cast(float, final_loss))
if rank == 0:
prof.export_chrome_trace(f"{optim_type}_trace.json")
@@ -147,7 +157,10 @@ def train(
else:
final_loss = optimizer.step(closure)
epoch_end = time.monotonic()
n_items += args.batch_size
batch_end = time.monotonic()
epoch_runtime += batch_end - batch_start
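# only time spent inside the training step is accumulated, so the per-epoch img/s below reflects step throughput rather than raw epoch wall time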
if optim_type == OptimType.oss:
# Check the checkpointing in the case of the OSS optimizer
@@ -156,33 +169,33 @@ def train(
optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
_ = optimizer.state_dict()
print("... State dict collected")
logging.info("... State dict collected")
measurements.append(args.data_size / (epoch_end - epoch_start))
measurements.append(n_items / epoch_runtime)
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")
logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")
if not args.cpu:
torch.cuda.synchronize(rank)
training_stop = time.monotonic()
img_per_sec = args.data_size / (training_stop - training_start) * args.epochs
img_per_sec = n_items / (training_stop - training_start) * args.epochs
max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
# Compute the mean and standard deviation of the measured img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
if check_regression and dist.get_rank() == 0:
assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"
print("[Regression Test] VALID")
logging.info("[Regression Test] VALID")
dist.destroy_process_group() # type: ignore
@@ -193,12 +206,11 @@ if __name__ == "__main__":
)
parser.add_argument("--world_size", action="store", default=2, type=int)
parser.add_argument("--epochs", action="store", default=10, type=int)
parser.add_argument("--batch_size", action="store", default=32, type=int)
parser.add_argument("--data_size", action="store", default=512, type=int)
parser.add_argument("--batch_size", action="store", default=256, type=int)
parser.add_argument("--check_regression", action="store_true", default=False)
parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
parser.add_argument("--reference_memory", action="store", default=4475, type=float)
parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
parser.add_argument("--reference_speed", action="store", default=1430, type=float)
parser.add_argument("--reference_memory", action="store", default=1220, type=float)
parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
parser.add_argument(
"--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
)
@@ -208,12 +220,33 @@ if __name__ == "__main__":
parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
args = parser.parse_args()
print(f"Benchmark arguments: {args}")
logging.basicConfig(level=logging.INFO)
logging.info(f"Benchmark arguments: {args}")
backend = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
# Download dataset once for all processes
dataset, tentatives = None, 0
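# retry the download a few times: a corrupted archive raises RuntimeError (wipe the partial data and start over), a truncated one raises EOFError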
while dataset is None and tentatives < 5:
try:
dataset = MNIST(transform=None, download=True, root=TEMPDIR)
except (RuntimeError, EOFError) as e:
if isinstance(e, RuntimeError):
# Corrupted data, erase and restart
shutil.rmtree(TEMPDIR + "/MNIST")
logging.warning("Failed loading dataset: ", e)
tentatives += 1
if dataset is None:
logging.error("Could not download MNIST dataset")
exit(-1)
else:
logging.info("Dataset downloaded")
if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
print("\nBenchmark vanilla optimizer")
logging.info("*** Benchmark vanilla optimizer")
mp.spawn(
train,
args=(args, backend, OptimType.vanilla, False,), # no regression check
@@ -222,13 +255,13 @@ if __name__ == "__main__":
)
if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS with DDP")
logging.info("*** Benchmark OSS with DDP")
mp.spawn(
train, args=(args, backend, OptimType.oss, args.check_regression), nprocs=args.world_size, join=True,
)
if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS with SDP")
logging.info("*** Benchmark OSS with SDP")
mp.spawn(
train,
args=(args, backend, OptimType.oss_sdp, False,), # FIXME: @lefaudeux - SDP should give the same results