Unverified Commit ee38e1e0 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Add a memory usage regression test to the OSS benchmark (#62)

* Aligning the optimizer state dict with what PyTorch expects

* Adding a check on the dict keys, ensuring that `state` and `param_groups` are there (see the sketch right after this list)

* after installing the specific isort and black versions, a one-liner to please the linter

* Adding some measurement of the memory consumption while training + checkpointing

* mandatory lintfix commit

* brainfart, reset the memory use counter at the beginning of the training in case two of them are run in a row

* move reset stats call, hotfix

* move the optimizer to RMSprop, which is more stateful and still used in CV

* trying to figure out a SIGSEGV in CircleCI
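
The first two notes above describe aligning OSS's state dict with the layout torch.optim produces and checking that the `state` and `param_groups` keys are present. Below is a minimal sketch of that contract, checked against a plain torch.optim.RMSprop instance (single process, no process group or OSS wrapper involved); OSS is expected to expose the same layout.

```python
# Minimal sketch of the state-dict contract referenced in the commit notes:
# a PyTorch-compatible optimizer state dict exposes "state" and "param_groups".
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, momentum=0.9)

# Take one step so the optimizer actually populates its per-parameter state.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

state_dict = optimizer.state_dict()
assert {"state", "param_groups"} <= set(state_dict.keys()), "unexpected state dict layout"
print(sorted(state_dict.keys()))  # ['param_groups', 'state']
```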
parent b6a5e634
@@ -5,7 +5,7 @@ import argparse
 import math
 import os
 import time
-from typing import Any, List
+from typing import Any, List, Union, cast
 import torch
 import torch.distributed as dist
@@ -19,6 +19,7 @@ from torchvision.transforms import ToTensor
 from fairscale.optim.oss import OSS
 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+OPTIM = torch.optim.RMSprop
 def dist_init(rank, world_size):
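
The hunk above swaps the benchmark's base optimizer for RMSprop, per the note about wanting something more stateful that is still used in CV. The standalone sketch below (illustrative, not part of the benchmark) shows the per-parameter buffers RMSprop keeps, which is what makes it a better workload for sharded optimizer state and memory than vanilla SGD:

```python
# Standalone sketch: RMSprop keeps per-parameter buffers (a running average of
# squared gradients, plus a momentum buffer when momentum > 0), unlike plain SGD.
import torch

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, momentum=0.9)

model(torch.randn(2, 16)).sum().backward()
optimizer.step()

for param, buffers in optimizer.state.items():
    # Expect 'square_avg' and 'momentum_buffer' entries (plus bookkeeping such as 'step').
    print(tuple(param.shape), sorted(buffers.keys()))
```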
@@ -36,7 +37,9 @@ def train(
     use_oss: bool = True,
     check_regression: bool = True,
     reference_speed: float = -1.0,
+    reference_memory: float = -1.0,
 ):
     # DDP
     dist_init(rank, world_size)
@@ -50,21 +53,18 @@ def train(
         "label": torch.stack([i[1] for i in inputs]).to(rank),
     }
-    def print_(msg):
-        if dist.get_rank() == 0:
-            print(msg)
     dataloader = DataLoader(
         dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
     )
     loss_fn = nn.CrossEntropyLoss()
+    # Reset the memory use counter
+    torch.cuda.reset_peak_memory_stats(rank)
     # Shard the optimizer
-    optimizer = (
-        OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
-        if use_oss
-        else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
-    )
+    optimizer: Union[OSS, OPTIM] = OSS(
+        params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
+    ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
     # Dummy training loop
     torch.cuda.synchronize(rank)
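
The memory-counter reset added above is what makes back-to-back benchmark runs in the same process comparable: without it, the second run inherits the first run's high-water mark. Below is a self-contained sketch of the measurement pattern; the benchmark's max_memory value is computed elsewhere in the file (outside this diff), and torch.cuda.max_memory_allocated, the usual counterpart of reset_peak_memory_stats, is assumed here:

```python
# Minimal sketch of the peak-memory measurement pattern: reset the counter
# before the workload, then read the high-water mark after it.
import torch

if torch.cuda.is_available():
    device = 0
    torch.cuda.reset_peak_memory_stats(device)  # clear any previous run's high-water mark

    x = torch.randn(1024, 1024, device=device)
    y = x @ x  # throwaway work that allocates on the GPU

    torch.cuda.synchronize(device)
    max_memory = torch.cuda.max_memory_allocated(device) / 2 ** 20
    print(f"Peak memory {max_memory:.1f}MiB")
else:
    print("CUDA not available, skipping the memory sketch")
```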
@@ -90,8 +90,19 @@
             optimizer.step(closure)
         epoch_end = time.monotonic()
+        if use_oss:
+            # Check the checkpointing in the case of the OSS optimizer
+            # Memory usage could spill over from there
+            optimizer = cast(OSS, optimizer)
+            # optimizer.consolidate_state_dict()
+            if dist.get_rank() == 0:
+                # _ = optimizer.state_dict()
+                print("... State dict collected")
         measurements.append(data_size / (epoch_end - epoch_start))
-        print_(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+        if dist.get_rank() == 0:
+            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
     torch.cuda.synchronize(rank)
     training_stop = time.monotonic()
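
The checkpoint exercise in this hunk is commented out while the CircleCI SIGSEGV mentioned in the commit notes is tracked down. For reference, the pattern the commented-out calls suggest is a collective consolidation on every rank followed by reading the full state dict on the recipient rank only. A hedged sketch, assuming an initialized process group and an OSS-wrapped optimizer:

```python
# Hedged sketch of the OSS checkpoint pattern that the commented-out lines above
# are exercising. Assumes torch.distributed is already initialized on every rank.
from typing import Any, Dict, Optional

import torch.distributed as dist
from fairscale.optim.oss import OSS


def collect_oss_state_dict(optimizer: OSS) -> Optional[Dict[str, Any]]:
    # Every rank must participate: consolidation gathers the sharded
    # per-parameter state onto the recipient rank (rank 0 by default).
    optimizer.consolidate_state_dict()

    if dist.get_rank() != 0:
        return None

    # Only the recipient rank holds the full, PyTorch-shaped state dict.
    state_dict = optimizer.state_dict()
    assert {"state", "param_groups"} <= set(state_dict.keys())
    return state_dict
```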
@@ -101,13 +112,15 @@
     print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
     print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
+    # Compute the mean and average img per second
+    mean = sum(measurements) / len(measurements)
+    diff = map(lambda x: pow(x - mean, 2.0), measurements)
+    std = math.sqrt(sum(diff) / (len(measurements) - 1))
+    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
     if use_oss and check_regression and dist.get_rank() == 0:
-        # Compute the mean and average img per second
-        mean = sum(measurements) / len(measurements)
-        diff = map(lambda x: pow(x - mean, 2.0), measurements)
-        std = math.sqrt(sum(diff) / (len(measurements) - 1))
-        print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")
-        assert (mean - 3.0 * std) < reference_speed, "Regression detected"
+        assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
+        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
         print("[Regression Test] VALID")
@@ -122,10 +135,11 @@ if __name__ == "__main__":
     parser.add_argument("--data_size", action="store", default=512, type=int)
     parser.add_argument("--check_regression", action="store", default=True, type=bool)
     parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+    parser.add_argument("--reference_memory", action="store", default=4475, type=float)
     args = parser.parse_args()
-    print("\nBenchmark vanilla SGD")
+    print("\nBenchmark vanilla optimizer")
     mp.spawn(
         train,
         args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
@@ -144,6 +158,7 @@ if __name__ == "__main__":
             True,
             args.check_regression,
             args.reference_speed,
+            args.reference_memory,
         ),
         nprocs=args.world_size,
         join=True,
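
The new --reference_memory flag is forwarded to train purely by position in the args tuple, so its slot has to line up with the reference_memory parameter added to the signature. A small sketch of how torch.multiprocessing.spawn forwards that tuple (the worker function and values below are illustrative only):

```python
# Illustrative sketch: mp.spawn calls the target as fn(rank, *args), so each
# entry of the args tuple must line up with the target's positional parameters.
import torch.multiprocessing as mp


def worker(rank, world_size, reference_speed, reference_memory):
    print(f"rank={rank} world_size={world_size} "
          f"ref_speed={reference_speed} ref_memory={reference_memory}")


if __name__ == "__main__":
    demo_world_size = 2
    mp.spawn(worker, args=(demo_world_size, 39.82, 4475.0), nprocs=demo_world_size, join=True)
```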