Unverified Commit 49a3d9bc authored by Benjamin Lefaudeux, committed by GitHub

[feat][minor] OSS benchmark - pick the model via args (#152)

* Minor quality-of-life change: makes debugging easier and makes it possible to test a host of models with the same code
parent 61bb32b5
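
The heart of the change is resolving a torchvision model constructor from its string name at runtime instead of hard-coding resnet101. A minimal standalone sketch of that lookup follows; the build_model helper and its error handling are illustrative, not part of the patch:

import importlib

def build_model(model_name: str):
    # Fetch the constructor named `model_name` from torchvision.models,
    # e.g. "resnet101", "resnet50" or "wide_resnet50_2".
    models = importlib.import_module("torchvision.models")
    try:
        constructor = getattr(models, model_name)
    except AttributeError:
        raise ValueError(f"torchvision.models has no model named {model_name!r}") from None
    return constructor()  # instantiate without pretrained weights

model = build_model("resnet18")

Any attribute exposed by torchvision.models can be requested this way, which is what lets the benchmark exercise a host of models without further code changes.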
@@ -3,6 +3,7 @@
 import argparse
 from enum import Enum
+import importlib
 import math
 import time
 from typing import Any, List, Optional, cast
@@ -16,7 +17,6 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torchvision.datasets import FakeData
-from torchvision.models import resnet101
 from torchvision.transforms import ToTensor

 from fairscale.nn.data_parallel import ShardedDataParallel
@@ -30,9 +30,10 @@ def dist_init(rank, world_size, backend):
     dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


-def get_problem(rank, data_size, batch_size, device):
-    # Standard RN101
-    model = resnet101(pretrained=False, progress=True).to(device)
+def get_problem(rank, data_size, batch_size, device, model_name: str):
+    # Select the desired model on the fly
+    print(f"Using {model_name} for benchmarking")
+    model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)

     # Data setup, dummy data
     def collate(inputs: List[Any]):
@@ -78,7 +79,7 @@ def train(
     torch.backends.cudnn.benchmark = False
     device = torch.device("cpu") if args.cpu else torch.device(rank)
-    model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device)
+    model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device, args.torchvision_model)

     # Shard the optimizer
     optimizer: Optional[torch.optim.Optimizer] = None
@@ -204,6 +205,7 @@ if __name__ == "__main__":
     parser.add_argument("--gloo", action="store_true", default=False)
     parser.add_argument("--profile", action="store_true", default=False)
     parser.add_argument("--cpu", action="store_true", default=False)
+    parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")

     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")
...
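
Putting the two halves of the diff together, here is a hedged end-to-end sketch of how the new --torchvision_model flag reaches model construction. The explicit list passed to parse_args is only there to make the snippet self-contained; the real benchmark parses sys.argv:

import argparse
import importlib

parser = argparse.ArgumentParser()
parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
args = parser.parse_args(["--torchvision_model", "resnet50"])  # e.g. benchmark resnet50 instead of the default

print(f"Using {args.torchvision_model} for benchmarking")
model = getattr(importlib.import_module("torchvision.models"), args.torchvision_model)(pretrained=False)

Recent torchvision releases deprecate the pretrained flag in favour of the weights argument, so the call above may emit a deprecation warning; the benchmark only needs an untrained model either way.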