Unverified commit 34f35fba, authored by Benjamin Lefaudeux, committed by GitHub

[feat][minor] OSS Benchmark - add a debug option for some tensor dumps (#166)

* Improve ease of use in the benchmark tool; add a debug option
parent a31b08a5
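With this patch the debug dumps can be combined with the existing benchmark flags, for example:

    python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3 --debug

(all of these flags are introduced or exercised by the diff below).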
@@ -114,7 +114,7 @@ run_oss_gloo: &run_oss_gloo
       - run:
           name: Run OSS with Gloo
           command: |
-            python benchmarks/oss.py --gloo --optim_type oss
+            python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3

 # -------------------------------------------------------------------------------------
benchmarks/oss.py
@@ -23,8 +23,8 @@ from torch.utils.data.distributed import DistributedSampler
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor

-from fairscale.nn.data_parallel import ShardedDataParallel
-from fairscale.optim.oss import OSS
+from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
+from fairscale.optim import OSS

 OPTIM = torch.optim.RMSprop
 TEMPDIR = tempfile.gettempdir()
@@ -58,8 +58,8 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
 class OptimType(str, Enum):
     vanilla = "pytorch"
-    oss = "oss"
-    oss_sdp = "oss_sdp"
+    oss_ddp = "oss_ddp"
+    oss_sharded_ddp = "oss_sharded_ddp"
     everyone = "everyone"
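For context, OptimType subclasses both str and Enum, so each member compares equal to its raw string value; this is what lets the benchmark map command-line values straight onto enum members. A minimal sketch of that behavior:

    from enum import Enum

    class OptimType(str, Enum):
        vanilla = "pytorch"
        oss_ddp = "oss_ddp"

    # str-Enum members compare equal to their raw strings,
    # so parsing a CLI argument is a plain constructor call.
    assert OptimType.oss_ddp == "oss_ddp"
    assert OptimType("pytorch") is OptimType.vanilla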
@@ -70,14 +70,15 @@ def train(
     optim_type: OptimType = OptimType.vanilla,
     check_regression: bool = True,
 ):
-    logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

     # DDP
     dist_init(rank=rank, world_size=args.world_size, backend=backend)

     # Setup
-    torch.cuda.set_device(rank)
-    torch.cuda.manual_seed(0)
+    if not args.cpu:
+        torch.cuda.set_device(rank)
+        torch.cuda.manual_seed(0)
     torch.manual_seed(0)  # also sets the cuda seed
     np.random.seed(0)
@@ -92,24 +93,24 @@
     optimizer: Optional[torch.optim.Optimizer] = None
     model = cast(nn.Module, model)

-    if optim_type == OptimType.oss_sdp:
-        ddp = ShardedDataParallel(
-            module=model,
+    if optim_type == OptimType.oss_sharded_ddp:
+        model = ShardedDDP(
+            model,
             optimizer=OPTIM,
             optimizer_params={"lr": 1e-4, "momentum": 0.9},
             world_size=args.world_size,
             broadcast_buffers=True,
         )
-        ddp.train()
-        optimizer = ddp.optimizer
-        model = ddp
+        optimizer = model.sharded_optimizer
     else:
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
+        model = DDP(model, device_ids=[rank], find_unused_parameters=False)  # type: ignore
         optimizer = (
             OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
-            if optim_type == OptimType.oss
+            if optim_type == OptimType.oss_ddp
             else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
         )
     optimizer = cast(torch.optim.Optimizer, optimizer)

     # Reset the memory use counter
     if not args.cpu:
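As a side note, OSS wraps a standard torch.optim optimizer class and shards its state across ranks. A minimal single-process sketch, assuming fairscale is installed (world_size=1 and the gloo backend are illustrative choices, not part of this patch; with one rank OSS degenerates to plain RMSprop):

    import os
    import torch
    import torch.distributed as dist
    from fairscale.optim import OSS

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = torch.nn.Linear(8, 2)
    # OSS takes the optimizer *class* plus its kwargs; the state is
    # partitioned across the ranks of the default process group.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.RMSprop, lr=1e-4, momentum=0.9)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    dist.destroy_process_group()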
@@ -133,21 +134,34 @@
             def closure():
                 model.zero_grad()
+                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
+                    logging.debug(
+                        "\nbefore: param {} -- grad {}".format(
+                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
+                        )
+                    )
+
                 outputs = model(batch["inputs"])
                 loss = loss_fn(outputs, batch["label"])
                 loss.backward()

-                if optim_type == OptimType.oss_sdp:
-                    ddp.reduce()  # Send the gradients to the appropriate shards
+                if optim_type == OptimType.oss_sharded_ddp:
+                    model.reduce()  # Send the gradients to the appropriate shards
+
+                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
+                    logging.debug(
+                        "after BW: param {} -- grad {}".format(
+                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
+                        )
+                    )
                 return loss

             if need_profiling and not args.cpu:
                 logging.info("Profiling the run")
-                with profiler.profile(use_cuda=True) as prof:  # type: ignore
+                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                     with profiler.record_function("batch"):
                         final_loss = optimizer.step(closure)
-                logging.info("profiling done, final loss ", cast(float, final_loss))
+                logging.info("profiling done")

                 if rank == 0:
                     prof.export_chrome_trace(f"{optim_type}_trace.json")
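The profiling pattern used above (a profile context, a named record_function region, then a Chrome trace export) can be reproduced standalone; a CPU-only sketch with made-up tensor shapes:

    import torch
    from torch.autograd import profiler

    x = torch.randn(64, 64, requires_grad=True)

    # record_shapes / profile_memory mirror the options added in this patch;
    # use_cuda is omitted so the sketch also runs on CPU-only machines.
    with profiler.profile(record_shapes=True, profile_memory=True) as prof:
        with profiler.record_function("batch"):
            (x @ x).sum().backward()

    prof.export_chrome_trace("example_trace.json")  # open in chrome://tracing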
@@ -157,12 +171,20 @@
             else:
                 final_loss = optimizer.step(closure)

+            if args.debug and rank == 0:
+                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
+                logging.debug(
+                    "after update: param {} -- grad {}".format(
+                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
+                    )
+                )
+
             n_items += args.batch_size
             batch_end = time.monotonic()
             epoch_runtime += batch_end - batch__start

-        if optim_type == OptimType.oss:
+        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
             # Check the checkpointing in the case of the OSS optimizer
             # Memory usage could spill over from there
             optimizer = cast(OSS, optimizer)
@@ -175,19 +197,23 @@
         if dist.get_rank() == 0:
             logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

+        max_memory = -1.0
+        if not args.cpu:
+            torch.cuda.synchronize(rank)
+            max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
+            logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

     training_stop = time.monotonic()
     img_per_sec = n_items / (training_stop - training_start) * args.epochs
-    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

-    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
-    logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
+    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

     # Compute the mean and standard deviation of the measured images per second
     mean = sum(measurements) / len(measurements)
     diff = map(lambda x: pow(x - mean, 2.0), measurements)
-    std = math.sqrt(sum(diff) / (len(measurements) - 1))
+    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
     logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

     if check_regression and dist.get_rank() == 0:
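The guarded std computation above uses the unbiased (n-1) sample estimator, which needs more than one data point; a standalone illustration with made-up per-epoch throughput numbers:

    import math

    measurements = [310.5, 312.0, 309.8]  # hypothetical img/sec per epoch
    mean = sum(measurements) / len(measurements)

    # Bessel-corrected sample standard deviation; the division by len - 1
    # is why the benchmark falls back to -1 when too few epochs ran.
    std = math.sqrt(sum((x - mean) ** 2 for x in measurements) / (len(measurements) - 1))
    print(f"{mean:.2f} +/- {std:.2f}")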
@@ -218,10 +244,11 @@ if __name__ == "__main__":
     parser.add_argument("--profile", action="store_true", default=False)
     parser.add_argument("--cpu", action="store_true", default=False)
     parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
+    parser.add_argument("--debug", action="store_true", default=False)

     args = parser.parse_args()
-    logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
     logging.info(f"Benchmark arguments: {args}")

     backend = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
@@ -245,8 +272,9 @@
     else:
         logging.info("Dataset downloaded")

     # Benchmark the different configurations, via multiple processes
     if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
-        logging.info("*** Benchmark vanilla optimizer")
+        logging.info("\n*** Benchmark vanilla optimizer")
         mp.spawn(
             train,
             args=(args, backend, OptimType.vanilla, False,),  # no regression check
@@ -254,17 +282,22 @@
             join=True,
         )

-    if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
-        logging.info("*** Benchmark OSS with DDP")
+    if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
+        logging.info("\n*** Benchmark OSS with DDP")
         mp.spawn(
-            train, args=(args, backend, OptimType.oss, args.check_regression), nprocs=args.world_size, join=True,
+            train, args=(args, backend, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
         )

-    if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
-        logging.info("*** Benchmark OSS with SDP")
+    if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
+        logging.info("\n*** Benchmark OSS with ShardedDDP")
         mp.spawn(
             train,
-            args=(args, backend, OptimType.oss_sdp, False,),  # FIXME: @lefaudeux - SDP should give the same results
+            args=(
+                args,
+                backend,
+                OptimType.oss_sharded_ddp,
+                False,
+            ),  # FIXME: @lefaudeux - SDP should give the same results
             nprocs=args.world_size,
             join=True,
         )
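For readers unfamiliar with the launcher: mp.spawn starts nprocs processes and prepends each process's rank to the argument tuple, which is why train receives rank as its first parameter even though it never appears in args=(...). A trimmed-down, hypothetical sketch:

    import torch.multiprocessing as mp

    def worker(rank: int, world_size: int, message: str):
        # mp.spawn injects `rank`; only `world_size` and `message` come from `args`.
        print(f"[{rank}/{world_size}] {message}")

    if __name__ == "__main__":
        mp.spawn(worker, args=(2, "hello"), nprocs=2, join=True)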