Unverified Commit 8a49a748 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Enabling ViT in OSS benchmarks (#322)

parent dd441e9d
@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose, Resize, ToTensor
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend):
def get_problem(rank, world_size, batch_size, device, model_name: str):
# Select the desired model on the fly
logging.info(f"Using {model_name} for benchmarking")
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
try:
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
except AttributeError:
model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)
# Data setup, duplicate the grey channels to get pseudo color
def collate(inputs: List[Any]):
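
The lookup above resolves the model constructor by name in torchvision.models first and falls back to timm.models on an AttributeError, which is what lets timm's ViT variants through. A minimal standalone sketch of that lookup-with-fallback pattern, assuming both torchvision and timm are installed (the model name is only an illustration):

import importlib

def resolve_model(model_name: str):
    # Unknown names raise AttributeError from torchvision.models and fall through to timm.models
    try:
        return getattr(importlib.import_module("torchvision.models"), model_name)
    except AttributeError:
        return getattr(importlib.import_module("timm.models"), model_name)

model = resolve_model("vit_base_patch16_224")(pretrained=False)

One trade-off of this pattern is that a genuine typo in the model name only surfaces as an AttributeError from the timm branch.
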
@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
"label": torch.tensor([i[1] for i in inputs]).to(device),
}
dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR)
# Transforms
transforms = []
if model_name.startswith("vit"):
# ViT models expect a fixed input size. Add an ad-hoc transform to resize the pictures accordingly
pic_size = int(model_name.split("_")[-1])
transforms.append(Resize(pic_size))
transforms.append(ToTensor())
dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)
sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)
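
MNIST images are 28x28, while timm's ViT entry points expect the fixed resolution encoded at the end of their name, so the Resize size is parsed from that suffix. A small illustration of the resulting transform, assuming a name in timm's usual vit_*_<resolution> form (the name below is an example, not taken from the diff):

from torchvision.transforms import Compose, Resize, ToTensor

model_name = "vit_base_patch16_224"        # illustrative timm ViT name
pic_size = int(model_name.split("_")[-1])  # -> 224
transform = Compose([Resize(pic_size), ToTensor()])  # 28x28 MNIST digits become 224x224 tensors
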
@@ -88,7 +101,7 @@ def train(
torch.backends.cudnn.benchmark = False
device = torch.device("cpu") if args.cpu else torch.device(rank)
model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)
model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)
# Shard the optimizer
optimizer: Optional[torch.optim.Optimizer] = None
@@ -259,7 +272,7 @@ if __name__ == "__main__":
parser.add_argument("--gloo", action="store_true", default=False)
parser.add_argument("--profile", action="store_true", default=False)
parser.add_argument("--cpu", action="store_true", default=False)
parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
@@ -293,8 +306,8 @@ if __name__ == "__main__":
if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark vanilla optimizer")
mp.spawn(
train,
args=(args, BACKEND, OptimType.vanilla, False,), # no regression check
train, # type: ignore
args=(args, BACKEND, OptimType.vanilla, False), # no regression check
nprocs=args.world_size,
join=True,
)
@@ -302,13 +315,13 @@ if __name__ == "__main__":
if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark OSS with DDP")
mp.spawn(
train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True, # type: ignore
)
if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark OSS with ShardedDDP")
mp.spawn(
train,
train, # type: ignore
args=(
args,
BACKEND,
......
@@ -12,3 +12,4 @@ torch >= 1.5.1
torchvision >= 0.6.0
# NOTE(msb) not a dependency but needed by torch
numpy == 1.17.4
timm == 0.3.4