"src/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "f62367a6c9cf26cffa17c9af5cbe28e5b7f6b91b"
Unverified Commit 49a3d9bc authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feat][minor] OSS benchmark - pick the model via args (#152)

* Minor, ease of life to debug and makes it possible to test a host of models with the same code
parent 61bb32b5
......@@ -3,6 +3,7 @@
import argparse
from enum import Enum
import importlib
import math
import time
from typing import Any, List, Optional, cast
......@@ -16,7 +17,6 @@ import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor
from fairscale.nn.data_parallel import ShardedDataParallel
......@@ -30,9 +30,10 @@ def dist_init(rank, world_size, backend):
dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
def get_problem(rank, data_size, batch_size, device):
# Standard RN101
model = resnet101(pretrained=False, progress=True).to(device)
def get_problem(rank, data_size, batch_size, device, model_name: str):
# Select the desired model on the fly
print(f"Using {model_name} for benchmarking")
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
# Data setup, dummy data
def collate(inputs: List[Any]):
......@@ -78,7 +79,7 @@ def train(
torch.backends.cudnn.benchmark = False
device = torch.device("cpu") if args.cpu else torch.device(rank)
model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device)
model, dataloader, loss_fn = get_problem(rank, args.data_size, args.batch_size, device, args.torchvision_model)
# Shard the optimizer
optimizer: Optional[torch.optim.Optimizer] = None
......@@ -204,6 +205,7 @@ if __name__ == "__main__":
parser.add_argument("--gloo", action="store_true", default=False)
parser.add_argument("--profile", action="store_true", default=False)
parser.add_argument("--cpu", action="store_true", default=False)
parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
args = parser.parse_args()
print(f"Benchmark arguments: {args}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment