Unverified Commit 6f8a8652 authored by Benjamin Lefaudeux, committed by GitHub

[feature] OSS: Use MNIST to benchmark (#159)

* switching to MNIST
* updating the reference values, should be good to go
* download dataset once for all processes
parent 577dcd98
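
The core of the change, condensed into a standalone sketch (not part of the commit itself): the parent process downloads MNIST once, retrying if the archive comes down corrupted, and every spawned worker only reads the cached copy, sharded per rank with a DistributedSampler. Names mirror the diff below; the empty training step and the two-process spawn are placeholders for illustration.

import shutil
import tempfile

import torch.multiprocessing as mp
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

TEMPDIR = tempfile.gettempdir()


def worker(rank: int, world_size: int, batch_size: int) -> None:
    # download=False: the parent already fetched the data, so ranks never race on the download
    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
    for images, labels in dataloader:
        pass  # training step goes here


if __name__ == "__main__":
    # Download the dataset once for all processes, retrying on corrupted archives
    dataset, tentatives = None, 0
    while dataset is None and tentatives < 5:
        try:
            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
        except RuntimeError:
            # Corrupted download: wipe the cache and try again
            shutil.rmtree(TEMPDIR + "/MNIST", ignore_errors=True)
            tentatives += 1

    assert dataset is not None, "Could not download MNIST"
    mp.spawn(worker, args=(2, 256), nprocs=2, join=True)
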
@@ -108,7 +108,7 @@ run_oss_benchmark: &run_oss_benchmark
   - run:
       name: Run OSS Benchmark
       command: |
-        python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 13.7 --reference_memory 4390 --reference_loss 0.302
+        python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 800 --reference_memory 1120 --reference_loss 0.049
 
 run_oss_gloo: &run_oss_gloo
   - run:
...
@@ -4,7 +4,10 @@
 import argparse
 from enum import Enum
 import importlib
+import logging
 import math
+import shutil
+import tempfile
 import time
 from typing import Any, List, Optional, cast
@@ -15,38 +18,40 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
-from torchvision.datasets import FakeData
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+from torch.utils.data.distributed import DistributedSampler
+from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor
 
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim.oss import OSS
 
 OPTIM = torch.optim.RMSprop
+TEMPDIR = tempfile.gettempdir()
 
 
 def dist_init(rank, world_size, backend):
-    print(f"Using backend: {backend}")
+    logging.info(f"Using backend: {backend}")
     dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
 
 
-def get_problem(rank, data_size, batch_size, device, model_name: str):
+def get_problem(rank, world_size, batch_size, device, model_name: str):
     # Select the desired model on the fly
-    print(f"Using {model_name} for benchmarking")
+    logging.info(f"Using {model_name} for benchmarking")
     model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
 
-    # Data setup, dummy data
+    # Data setup, duplicate the grey channels to get pseudo color
     def collate(inputs: List[Any]):
         return {
-            "inputs": torch.stack([i[0] for i in inputs]).to(device),
-            "label": torch.stack([i[1] for i in inputs]).to(device),
+            "inputs": torch.stack([i[0] for i in inputs]).repeat(1, 3, 1, 1).to(device),
+            "label": torch.tensor([i[1] for i in inputs]).to(device),
         }
 
-    dataloader = DataLoader(
-        dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
-        batch_size=batch_size,
-        collate_fn=collate,
-    )
+    dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
+    sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
+    dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)
 
     loss_fn = nn.CrossEntropyLoss()
 
     return model, dataloader, loss_fn
@@ -65,6 +70,8 @@ def train(
     optim_type: OptimType = OptimType.vanilla,
     check_regression: bool = True,
 ):
+    logging.basicConfig(level=logging.INFO)
+
     # DDP
     dist_init(rank=rank, world_size=args.world_size, backend=backend)
@@ -79,10 +86,11 @@ def train(
         torch.backends.cudnn.benchmark = False
 
     device = torch.device("cpu") if args.cpu else torch.device(rank)
-    model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device, args.torchvision_model)
+    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)
 
     # Shard the optimizer
     optimizer: Optional[torch.optim.Optimizer] = None
+    model = cast(nn.Module, model)
 
     if optim_type == OptimType.oss_sdp:
         ddp = ShardedDataParallel(
@@ -108,7 +116,7 @@ def train(
         torch.cuda.reset_peak_memory_stats(rank)
         torch.cuda.synchronize(rank)
 
-    # Dummy training loop
+    # Standard training loop
     training_start = time.monotonic()
     model.train()
@@ -117,9 +125,11 @@ def train(
     need_profiling = args.profile
 
     for epoch in range(args.epochs):
-        epoch_start = time.monotonic()
+        n_items = 0
+        epoch_runtime = 0.0
 
         for batch in dataloader:
+            batch__start = time.monotonic()
 
             def closure():
                 model.zero_grad()
@@ -133,11 +143,11 @@ def train(
                 return loss
 
             if need_profiling and not args.cpu:
-                print("Profiling the run")
+                logging.info("Profiling the run")
                 with profiler.profile(use_cuda=True) as prof:  # type: ignore
                     with profiler.record_function("batch"):
                         final_loss = optimizer.step(closure)
-                print("profiling done, final loss ", cast(float, final_loss))
+                logging.info("profiling done, final loss ", cast(float, final_loss))
 
                 if rank == 0:
                     prof.export_chrome_trace(f"{optim_type}_trace.json")
@@ -147,7 +157,10 @@ def train(
             else:
                 final_loss = optimizer.step(closure)
 
-        epoch_end = time.monotonic()
+            n_items += args.batch_size
+
+            batch_end = time.monotonic()
+            epoch_runtime += batch_end - batch__start
 
         if optim_type == OptimType.oss:
             # Check the checkpointing in the case of the OSS optimizer
@@ -156,33 +169,33 @@ def train(
             optimizer.consolidate_state_dict()
             if dist.get_rank() == 0:
                 _ = optimizer.state_dict()
-                print("... State dict collected")
+                logging.info("... State dict collected")
 
-        measurements.append(args.data_size / (epoch_end - epoch_start))
+        measurements.append(n_items / epoch_runtime)
         if dist.get_rank() == 0:
-            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")
+            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")
 
     if not args.cpu:
         torch.cuda.synchronize(rank)
     training_stop = time.monotonic()
-    img_per_sec = args.data_size / (training_stop - training_start) * args.epochs
+    img_per_sec = n_items / (training_stop - training_start) * args.epochs
     max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
 
-    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
-    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
+    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
+    logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
 
     # Compute the mean and average img per second
    mean = sum(measurements) / len(measurements)
     diff = map(lambda x: pow(x - mean, 2.0), measurements)
     std = math.sqrt(sum(diff) / (len(measurements) - 1))
-    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
+    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
 
     if check_regression and dist.get_rank() == 0:
         assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
         assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
         assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"
-        print("[Regression Test] VALID")
+        logging.info("[Regression Test] VALID")
 
     dist.destroy_process_group()  # type: ignore
@@ -193,12 +206,11 @@ if __name__ == "__main__":
     )
     parser.add_argument("--world_size", action="store", default=2, type=int)
     parser.add_argument("--epochs", action="store", default=10, type=int)
-    parser.add_argument("--batch_size", action="store", default=32, type=int)
-    parser.add_argument("--data_size", action="store", default=512, type=int)
+    parser.add_argument("--batch_size", action="store", default=256, type=int)
     parser.add_argument("--check_regression", action="store_true", default=False)
-    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
-    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
-    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
+    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
+    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
+    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
     parser.add_argument(
         "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
     )
@@ -208,12 +220,33 @@ if __name__ == "__main__":
     parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
 
     args = parser.parse_args()
-    print(f"Benchmark arguments: {args}")
+
+    logging.basicConfig(level=logging.INFO)
+    logging.info(f"Benchmark arguments: {args}")
 
     backend = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
 
+    # Download dataset once for all processes
+    dataset, tentatives = None, 0
+    while dataset is None and tentatives < 5:
+        try:
+            dataset = MNIST(transform=None, download=True, root=TEMPDIR)
+        except (RuntimeError, EOFError) as e:
+            if isinstance(e, RuntimeError):
+                # Corrupted data, erase and restart
+                shutil.rmtree(TEMPDIR + "/MNIST")
+
+            logging.warning("Failed loading dataset: ", e)
+            tentatives += 1
+
+    if dataset is None:
+        logging.error("Could not download MNIST dataset")
+        exit(-1)
+    else:
+        logging.info("Dataset downloaded")
+
     if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
-        print("\nBenchmark vanilla optimizer")
+        logging.info("*** Benchmark vanilla optimizer")
         mp.spawn(
             train,
             args=(args, backend, OptimType.vanilla, False,),  # no regression check
@@ -222,13 +255,13 @@ if __name__ == "__main__":
         )
 
     if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
-        print("\nBenchmark OSS with DDP")
+        logging.info("*** Benchmark OSS with DDP")
         mp.spawn(
             train, args=(args, backend, OptimType.oss, args.check_regression), nprocs=args.world_size, join=True,
         )
 
     if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
-        print("\nBenchmark OSS with SDP")
+        logging.info("*** Benchmark OSS with SDP")
         mp.spawn(
             train,
             args=(args, backend, OptimType.oss_sdp, False,),  # FIXME: @lefaudeux - SDP should give the same results
...
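
For reference, the regression gate above passes when the mean per-epoch throughput plus three standard deviations clears the reference speed, peak memory stays within 5% of the reference, and the final loss lands within 1e-3 of the reference loss (800 img/s, 1120 MiB and 0.049 in the CI invocation). A toy check with made-up measurements, illustrative only and not actual benchmark output:

import math

# Per-epoch img/s measurements, one entry per epoch (made-up values)
measurements = [795.0, 812.0, 804.0]
final_loss, max_memory = 0.0489, 1103.0  # also made up

mean = sum(measurements) / len(measurements)
std = math.sqrt(sum((m - mean) ** 2 for m in measurements) / (len(measurements) - 1))

assert (mean + 3.0 * std) > 800, "Speed regression detected"
assert max_memory < 1.05 * 1120, "Memory use regression detected"
assert abs(final_loss - 0.049) < 1e-3, "Loss regression detected"
print(f"Mean speed: {mean:.2f} +/- {std:.2f} -> regression test passes")
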