Unverified Commit 8a49a748 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feat] Enabling ViT in OSS benchmarks (#322)

parent dd441e9d
...@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import BatchSampler, DataLoader, Sampler from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor from torchvision.transforms import Compose, Resize, ToTensor
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS from fairscale.optim import OSS
...@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend): ...@@ -39,7 +39,11 @@ def dist_init(rank, world_size, backend):
def get_problem(rank, world_size, batch_size, device, model_name: str): def get_problem(rank, world_size, batch_size, device, model_name: str):
# Select the desired model on the fly # Select the desired model on the fly
logging.info(f"Using {model_name} for benchmarking") logging.info(f"Using {model_name} for benchmarking")
try:
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device) model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
except AttributeError:
model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)
# Data setup, duplicate the grey channels to get pseudo color # Data setup, duplicate the grey channels to get pseudo color
def collate(inputs: List[Any]): def collate(inputs: List[Any]):
...@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str): ...@@ -48,7 +52,16 @@ def get_problem(rank, world_size, batch_size, device, model_name: str):
"label": torch.tensor([i[1] for i in inputs]).to(device), "label": torch.tensor([i[1] for i in inputs]).to(device),
} }
dataset = MNIST(transform=ToTensor(), download=False, root=TEMPDIR) # Transforms
transforms = []
if model_name.startswith("vit"):
# ViT models are fixed size. Add an ad-hoc transform to resize the pictures accordingly # ViT models are fixed size. Add an ad-hoc transform to resize the pictures accordingly
pic_size = int(model_name.split("_")[-1])
transforms.append(Resize(pic_size))
transforms.append(ToTensor())
dataset = MNIST(transform=Compose(transforms), download=False, root=TEMPDIR)
sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) sampler: Sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
batch_sampler = BatchSampler(sampler, batch_size, drop_last=True) batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate) dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate)
...@@ -88,7 +101,7 @@ def train( ...@@ -88,7 +101,7 @@ def train(
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
device = torch.device("cpu") if args.cpu else torch.device(rank) device = torch.device("cpu") if args.cpu else torch.device(rank)
model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model) model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)
# Shard the optimizer # Shard the optimizer
optimizer: Optional[torch.optim.Optimizer] = None optimizer: Optional[torch.optim.Optimizer] = None
...@@ -259,7 +272,7 @@ if __name__ == "__main__": ...@@ -259,7 +272,7 @@ if __name__ == "__main__":
parser.add_argument("--gloo", action="store_true", default=False) parser.add_argument("--gloo", action="store_true", default=False)
parser.add_argument("--profile", action="store_true", default=False) parser.add_argument("--profile", action="store_true", default=False)
parser.add_argument("--cpu", action="store_true", default=False) parser.add_argument("--cpu", action="store_true", default=False)
parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101") parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information") parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP") parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
...@@ -293,8 +306,8 @@ if __name__ == "__main__": ...@@ -293,8 +306,8 @@ if __name__ == "__main__":
if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone: if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark vanilla optimizer") logging.info("\n*** Benchmark vanilla optimizer")
mp.spawn( mp.spawn(
train, train, # type: ignore
args=(args, BACKEND, OptimType.vanilla, False,), # no regression check args=(args, BACKEND, OptimType.vanilla, False), # no regression check
nprocs=args.world_size, nprocs=args.world_size,
join=True, join=True,
) )
...@@ -302,13 +315,13 @@ if __name__ == "__main__": ...@@ -302,13 +315,13 @@ if __name__ == "__main__":
if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone: if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark OSS with DDP") logging.info("\n*** Benchmark OSS with DDP")
mp.spawn( mp.spawn(
train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True, train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True, # type: ignore
) )
if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone: if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark OSS with ShardedDDP") logging.info("\n*** Benchmark OSS with ShardedDDP")
mp.spawn( mp.spawn(
train, train, # type: ignore
args=( args=(
args, args,
BACKEND, BACKEND,
......
...@@ -12,3 +12,4 @@ torch >= 1.5.1 ...@@ -12,3 +12,4 @@ torch >= 1.5.1
torchvision >= 0.6.0 torchvision >= 0.6.0
# NOTE(msb) not a dependency but needed by torch # NOTE(msb) not a dependency but needed by torch
numpy == 1.17.4 numpy == 1.17.4
timm == 0.3.4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment