Unverified commit 34915bf8, authored by Benjamin Lefaudeux, committed by GitHub
[feat] OSS: adding a --profile option to the benchmark (#135)

parent 37c686e7
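
The new --profile path wraps a single optimizer.step() call in torch.autograd.profiler and exports a Chrome trace from rank 0. A minimal, standalone sketch of that pattern is shown below (the toy model, tensor shapes, and the trace.json file name are illustrative stand-ins, not part of the patch):

    # Sketch of the profiling pattern used by the --profile path (illustrative only).
    # A small Linear model stands in for the benchmark's real model/optimizer setup;
    # CUDA timing is only enabled when a GPU is actually available.
    import torch
    import torch.autograd.profiler as profiler

    model = torch.nn.Linear(128, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    inputs = torch.randn(32, 128)
    targets = torch.randint(0, 10, (32,))
    loss_fn = torch.nn.CrossEntropyLoss()

    def closure():
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        return loss

    with profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
        with profiler.record_function("batch"):
            final_loss = optimizer.step(closure)

    # Writes a Chrome-compatible trace, viewable via chrome://tracing
    prof.export_chrome_trace("trace.json")
    print("final loss", float(final_loss))
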
@@ -9,6 +9,7 @@ from typing import Any, List, Optional, cast
 
 import numpy as np
 import torch
+import torch.autograd.profiler as profiler
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -49,6 +50,13 @@ def get_problem(rank, data_size, batch_size):
     return model, dataloader, loss_fn
 
 
+class OptimType(str, Enum):
+    vanilla = "pytorch"
+    oss = "oss"
+    oss_sdp = "oss_sdp"
+    everyone = "everyone"
+
+
 def train(
     rank: int,
     world_size: int,
@@ -56,14 +64,13 @@ def train(
     batch_size: int = 32,
     data_size: int = 200,
     backend: str = "gloo",
-    use_oss: bool = True,
-    use_sdp: bool = False,
+    optim_type: OptimType = OptimType.vanilla,
+    profile: bool = False,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
     reference_loss: float = -1.0,
 ):
-    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
 
     # DDP
     dist_init(rank=rank, world_size=world_size, backend=backend)
@@ -82,7 +89,7 @@ def train(
 
     # Shard the optimizer
     optimizer: Optional[torch.optim.Optimizer] = None
 
-    if use_sdp:
+    if optim_type == OptimType.oss_sdp:
         ddp = ShardedDataParallel(
             module=model,
             optimizer=OPTIM,
@@ -97,7 +104,7 @@ def train(
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
         optimizer = (
             OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
-            if use_oss
+            if optim_type == OptimType.oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
         )
@@ -111,6 +118,7 @@ def train(
     measurements = []
     final_loss: Optional[float] = -1.0
+    need_profiling = profile
 
     for epoch in range(num_epochs):
         epoch_start = time.monotonic()
@@ -124,16 +132,29 @@ def train(
                 loss /= world_size
                 loss.backward()
 
-                if use_sdp:
+                if optim_type == OptimType.oss_sdp:
                     ddp.reduce()  # Send the gradients to the appropriate shards
 
                 return loss
 
-            final_loss = optimizer.step(closure)
+            if need_profiling:
+                print("Profiling the run")
+                with profiler.profile(use_cuda=True) as prof:  # type: ignore
+                    with profiler.record_function("batch"):
+                        final_loss = optimizer.step(closure)
+                        print("profiling done, final loss ", cast(float, final_loss))
+
+                if rank == 0:
+                    prof.export_chrome_trace(f"{optim_type}_trace.json")
+
+                need_profiling = False  # only profile once
+            else:
+                final_loss = optimizer.step(closure)
 
         epoch_end = time.monotonic()
 
-        if use_oss:
+        if optim_type == OptimType.oss:
             # Check the checkpointing in the case of the OSS optimizer
             # Memory usage could spill over from there
             optimizer = cast(OSS, optimizer)
@@ -160,7 +181,7 @@ def train(
             std = math.sqrt(sum(diff) / (len(measurements) - 1))
             print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
 
-            if use_oss and check_regression and dist.get_rank() == 0:
+            if check_regression and dist.get_rank() == 0:
                 assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
                 assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
                 assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"
@@ -171,13 +192,6 @@ def train(
 
 if __name__ == "__main__":
 
-    class OptimType(str, Enum):
-        vanilla = "pytorch"
-        oss = "oss"
-        oss_sdp = "oss_sdp"
-        everyone = "everyone"
-
     parser = argparse.ArgumentParser(
         description="Benchmark the optimizer state sharding, on a typical computer vision workload"
     )
@@ -193,6 +207,7 @@ if __name__ == "__main__":
         "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
     )
     parser.add_argument("--gloo", action="store_true", default=False)
+    parser.add_argument("--profile", action="store_true", default=False)
 
     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")
@@ -209,8 +224,8 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                False,  # OSS
-                False,  # SDP
+                OptimType.vanilla,
+                args.profile,
                 False,  # no regression check
             ),
             nprocs=args.world_size,
@@ -227,8 +242,8 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,  # OSS
-                False,  # SDP
+                OptimType.oss,
+                args.profile,
                 args.check_regression,
                 args.reference_speed,
                 args.reference_memory,
@@ -248,8 +263,8 @@ if __name__ == "__main__":
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,  # OSS
-                True,  # SDP
+                OptimType.oss_sdp,
+                args.profile,
                 False,  # FIXME: @lefaudeux - SDP should give the same results
                 -1,  # Not checking SDP for speed regression for now, still slower than OSS
                 args.reference_memory,
...
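
For context, the patch also promotes OptimType from the __main__ block to module scope so that train() can take a single optim_type argument in place of the old use_oss/use_sdp booleans. A small standalone sketch of why the str-backed Enum works with the existing argparse wiring (the parse_args input below is made up for illustration):

    # Because OptimType subclasses str, argparse can convert the raw string with
    # type=OptimType and still validate the converted value against the enum's
    # string values listed in choices.
    import argparse
    from enum import Enum

    class OptimType(str, Enum):
        vanilla = "pytorch"
        oss = "oss"
        oss_sdp = "oss_sdp"
        everyone = "everyone"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
    )
    args = parser.parse_args(["--optim_type", "oss"])
    assert args.optim_type == OptimType.oss  # string compares equal to the enum member
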