Unverified commit 53553474, authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS benchmark cleanup (#109)

- small benchmark refactor: a single script now covers all backends and DDP
- deterministic runs, enforcing alignment with PyTorch DDP
parent 7c5203eb
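
The refactor folds the OSS-DDP path into the single train() entry point; which variant runs is picked with --optim_type (pytorch, oss, oss_sdp, or everyone, the default). As a rough usage sketch, mirroring the updated CI commands in the diff below:

    python benchmarks/oss.py --check_regression        # run every variant; regression checks on the OSS run
    python benchmarks/oss.py --gloo --optim_type oss   # Gloo backend, OSS path only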
@@ -100,14 +100,9 @@ run_oss_benchmark: &run_oss_benchmark
   - run:
       name: Run OSS Benchmark
       command: |
-        python benchmarks/oss.py
-        python benchmarks/oss.py --gloo
+        python benchmarks/oss.py --check_regression
+        python benchmarks/oss.py --gloo --optim_type oss
-
-run_oss_ddp_benchmark: &run_oss_ddp_benchmark
-  - run:
-      name: Run OSS DDP Benchmark
-      command: |
-        python benchmarks/oss.py --oss_ddp

 # -------------------------------------------------------------------------------------
 # Jobs to run

@@ -259,8 +254,6 @@ jobs:
       - <<: *run_oss_benchmark
-
-      - <<: *run_oss_ddp_benchmark

 workflows:
benchmarks/oss.py:

@@ -2,10 +2,12 @@
 import argparse
+from enum import Enum
 import math
 import time
 from typing import Any, List, Optional, cast

+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -44,74 +46,6 @@ def get_problem(rank, data_size, batch_size):
     return model, dataloader, loss_fn

-
-def train_oss_ddp(
-    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200, backend: str = "gloo",
-):
-    # DDP
-    dist_init(rank, world_size, backend)
-
-    # Setup
-    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
-    ddp = ShardedDataParallel(
-        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
-    )
-    optimizer = ddp.optimizer
-
-    # Reset the memory use counter
-    torch.cuda.reset_peak_memory_stats(rank)
-
-    # Dummy training loop
-    torch.cuda.synchronize(rank)
-    training_start = time.monotonic()
-    model.train()
-
-    measurements = []
-    for epoch in range(num_epochs):
-        epoch_start = time.monotonic()
-
-        for batch in dataloader:
-
-            def closure():
-                model.zero_grad()
-                outputs = model(batch["inputs"])
-                loss = loss_fn(outputs, batch["label"])
-                loss /= world_size
-                loss.backward()
-
-                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
-                if dist.get_rank() == 0:
-                    print(f"Loss: {loss.item()}")
-
-                ddp.reduce()  # Send the gradients to the appropriate shards
-                return loss
-
-            optimizer.step(closure)
-
-        epoch_end = time.monotonic()
-        measurements.append(data_size / (epoch_end - epoch_start))
-        if dist.get_rank() == 0:
-            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
-
-    torch.cuda.synchronize(rank)
-    training_stop = time.monotonic()
-    img_per_sec = data_size / (training_stop - training_start) * num_epochs
-    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
-
-    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
-    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
-
-    # Compute the mean and average img per second
-    mean = sum(measurements) / len(measurements)
-    diff = map(lambda x: pow(x - mean, 2.0), measurements)
-    std = math.sqrt(sum(diff) / (len(measurements) - 1))
-    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
-
-
 def train(
     rank: int,
     world_size: int,
@@ -120,18 +54,40 @@ def train(
     data_size: int = 200,
     backend: str = "gloo",
     use_oss: bool = True,
+    use_sdp: bool = False,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
+    reference_loss: float = -1.0,
 ):
+    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
+
     # DDP
     dist_init(rank, world_size, backend)

     # Setup
+    torch.cuda.set_device(rank)
+    torch.cuda.manual_seed(0)
+    torch.manual_seed(0)  # also sets the cuda seed
+    np.random.seed(0)
+
+    if backend == "nccl":
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

     # Shard the optimizer
-    optimizer: torch.optim.Optimizer = (
-        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
-        if use_oss
-        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
+    optimizer: Optional[torch.optim.Optimizer] = None
+
+    if use_sdp:
+        ddp = ShardedDataParallel(
+            module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
+        )
+        ddp.train()
+        optimizer = ddp.optimizer
+        model = ddp
+    else:
+        optimizer = (
+            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
+            if use_oss
+            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
@@ -162,6 +118,9 @@ def train(
                 dist.all_reduce(loss, op=dist.ReduceOp.SUM)

+                if use_sdp:
+                    ddp.reduce()  # Send the gradients to the appropriate shards
+
                 return loss

             final_loss = optimizer.step(closure)
@@ -179,7 +138,7 @@ def train(
         measurements.append(data_size / (epoch_end - epoch_start))
         if dist.get_rank() == 0:
-            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")
+            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

     torch.cuda.synchronize(rank)
     training_stop = time.monotonic()
@@ -198,11 +157,19 @@ def train(
     if use_oss and check_regression and dist.get_rank() == 0:
         assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
         assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
+        assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"
         print("[Regression Test] VALID")


 if __name__ == "__main__":
+
+    class OptimType(str, Enum):
+        vanilla = "pytorch"
+        oss = "oss"
+        oss_sdp = "oss_sdp"
+        everyone = "everyone"
+
     parser = argparse.ArgumentParser(
         description="Benchmark the optimizer state sharding, on a typical computer vision workload"
     )
@@ -211,34 +178,38 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", action="store", default=32, type=int)
     parser.add_argument("--data_size", action="store", default=512, type=int)
     parser.add_argument("--check_regression", action="store_true", default=False)
-    parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
+    parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
     parser.add_argument("--reference_memory", action="store", default=4475, type=float)
+    parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
+    parser.add_argument(
+        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
+    )
     parser.add_argument("--gloo", action="store_true", default=False)
-    # beta - test oss_ddp
-    parser.add_argument("--oss_ddp", action="store_true", default=False)

     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")

     backend = "nccl" if not args.gloo or not torch.cuda.is_available() else "gloo"
-    if args.oss_ddp:
-        print("\nBenchmark OSS DDP")
-        mp.spawn(
-            train_oss_ddp,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend),
-            nprocs=args.world_size,
-            join=True,
-        )
-    else:
+
+    if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
         print("\nBenchmark vanilla optimizer")
         mp.spawn(
             train,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend, False, False),
+            args=(
+                args.world_size,
+                args.epochs,
+                args.batch_size,
+                args.data_size,
+                backend,
+                False,  # OSS
+                False,  # SDP
+                False,  # no regression check
+            ),
             nprocs=args.world_size,
             join=True,
         )
+
+    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
         print("\nBenchmark OSS")
         mp.spawn(
             train,
@@ -248,10 +219,30 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,
+                True,  # OSS
+                False,  # SDP
                 args.check_regression,
                 args.reference_speed,
                 args.reference_memory,
+                args.reference_loss,
             ),
             nprocs=args.world_size,
             join=True,
         )
+
+    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
+        print("\nBenchmark OSS DDP")
+        mp.spawn(
+            train,
+            args=(
+                args.world_size,
+                args.epochs,
+                args.batch_size,
+                args.data_size,
+                backend,
+                True,  # OSS
+                True,  # SDP
+                False,  # no regression check
+            ),
+            nprocs=args.world_size,
+            join=True,
+        )