Unverified Commit de713d1e authored by Benjamin Lefaudeux, committed by GitHub

[feat][minor] OSS Benchmark - regression test + background testing new optims (#352)

* Restore the regression test; add a test of the for_each optimizers
* Fix the regression test on CircleCI
* Remove unused flags
parent 011c0c41
@@ -179,7 +179,7 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py --world_size 4 --epochs 2
-        python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp --reference_speed 660 --reference_memory 930 --reference_loss 0.023
+        python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp

 run_oss_gloo: &run_oss_gloo
   - run:
@@ -194,6 +194,12 @@ run_oss_amp: &run_oss_amp
       command: |
         python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp

+run_oss_for_each: &run_oss_for_each
+  - run:
+      name: Run OSS with Torch AMP and ForEach optimizer
+      command: |
+        python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp --multi_tensor_optim
+
 run_doc_build: &run_doc_build
   - run:
@@ -458,6 +464,7 @@ jobs:
       - <<: *run_oss_amp
+      - <<: *run_oss_for_each
@@ -4,9 +4,9 @@
 def get_golden_real_stats():
     return {
-        "reference_speed": 1430,
-        "reference_memory": 1220,
-        "reference_loss": 0.006,
+        "reference_speed": 660,
+        "reference_memory": 1000,
+        "reference_loss": 0.026,
     }
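These golden values are what `validate_benchmark` in benchmarks/oss.py compares the measured throughput, peak memory, and final loss against when `--check_regression` is set, replacing the `--reference_*` CLI flags removed above. The sketch below is only a hedged approximation of that check: the actual function takes `(measurements, final_loss, args, check_regression)`, and the ~10% / 0.01 tolerances and measurement field names here are assumptions, not the benchmark's real values.

```python
# Hedged sketch of a regression check against the golden stats above.
def get_golden_real_stats():
    # Mirrors the updated golden config.
    return {"reference_speed": 660, "reference_memory": 1000, "reference_loss": 0.026}


def validate_benchmark(measurements, final_loss, check_regression=True):
    if not check_regression:
        return
    golden = get_golden_real_stats()
    # Throughput (img/sec) should not fall more than ~10% below the reference.
    assert measurements["img_per_sec"] > 0.9 * golden["reference_speed"], "speed regression"
    # Peak memory (MB) should not grow more than ~10% above the reference.
    assert measurements["peak_memory_mb"] < 1.1 * golden["reference_memory"], "memory regression"
    # Final loss should land close to the reference value.
    assert abs(final_loss - golden["reference_loss"]) < 0.01, "loss regression"


# Example usage with made-up measurements:
validate_benchmark({"img_per_sec": 680.0, "peak_memory_mb": 980.0}, final_loss=0.025)
```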
@@ -28,7 +28,6 @@ from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
 from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler

-OPTIM = torch.optim.RMSprop
 TEMPDIR = tempfile.gettempdir()

@@ -78,7 +77,7 @@ class OptimType(str, Enum):
     everyone = "everyone"


-def validate_benchmark(measurements, args, check_regression):
+def validate_benchmark(measurements, final_loss, args, check_regression):
     """Validate the measurements against the golden benchmark config."""

     golden_data = oss_mnist.get_golden_real_stats()
@@ -118,6 +117,10 @@ def train(
 ):
     logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

+    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
+    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is checked but mypy misses that
+    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))
+
     # DDP
     dist_init(rank=rank, world_size=args.world_size, backend=backend)
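The `torch.optim._multi_tensor` variants are the "for-each" implementations of the stock optimizers: rather than launching one small op per parameter tensor, they batch the update across the whole parameter list with the private `torch._foreach_*` primitives, which is where the expected speed-up comes from. A minimal sketch of that difference (tensor shapes and the step size are illustrative only):

```python
import torch

# A toy parameter list standing in for a model's parameters.
params = [torch.randn(1024, 1024) for _ in range(8)]
grads = [torch.randn_like(p) for p in params]

# Per-parameter loop: one op (and, on GPU, one kernel launch) per tensor.
for p, g in zip(params, grads):
    p.add_(g, alpha=-0.01)

# Multi-tensor ("for each") path: a single call covering the whole list,
# which is what the torch.optim._multi_tensor optimizers use internally.
torch._foreach_add_(params, grads, alpha=-0.01)
```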
@@ -260,7 +263,7 @@ def train(
     img_per_sec = n_items / (training_stop - training_start) * args.epochs
     logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

-    validate_benchmark(measurements, args, check_regression)
+    validate_benchmark(measurements, final_loss, args, check_regression)

     dist.destroy_process_group()  # type: ignore
@@ -273,9 +276,6 @@ if __name__ == "__main__":
     parser.add_argument("--epochs", action="store", default=10, type=int)
     parser.add_argument("--batch_size", action="store", default=256, type=int)
     parser.add_argument("--check_regression", action="store_true", default=False)
-    parser.add_argument("--reference_speed", action="store", default=1430, type=float)
-    parser.add_argument("--reference_memory", action="store", default=1220, type=float)
-    parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
     parser.add_argument(
         "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
     )
@@ -285,6 +285,9 @@ if __name__ == "__main__":
     parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
     parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
     parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
+    parser.add_argument(
+        "--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
+    )

     args = parser.parse_args()
@@ -332,12 +335,7 @@ if __name__ == "__main__":
         logging.info("\n*** Benchmark OSS with ShardedDDP")
         mp.spawn(
             train,  # type: ignore
-            args=(
-                args,
-                BACKEND,
-                OptimType.oss_sharded_ddp,
-                False,
-            ),  # FIXME: @lefaudeux - SDP should give the same results
+            args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression,),
             nprocs=args.world_size,
             join=True,
         )
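For readers unfamiliar with the launch pattern here: `torch.multiprocessing.spawn` starts `nprocs` worker processes and calls the target with the worker rank as the first positional argument, followed by the tuple passed as `args`. The change above simply forwards the user-supplied `--check_regression` flag instead of the hard-coded `False`, so the ShardedDDP run is now regression-checked too. A stripped-down sketch of that calling convention (the `train` body and argument values are placeholders, not the benchmark's):

```python
import torch.multiprocessing as mp


def train(rank: int, world_size: int, check_regression: bool) -> None:
    # `rank` is injected by mp.spawn; everything else comes from `args=` below.
    print(f"worker {rank}/{world_size}, check_regression={check_regression}")


if __name__ == "__main__":
    # Mirrors the benchmark's launch: one process per rank, joined before exit.
    mp.spawn(train, args=(4, True), nprocs=4, join=True)
```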