Commit e0422994 authored by Kai Chen

update cifar10 example

parent 961c3388
......
@@ -8,23 +8,24 @@ batch_size = 64
 # optimizer and learning rate
 optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
-lr_policy = dict(policy='step', step=2)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(policy='step', step=2)
 # runtime settings
 work_dir = './demo'
 gpus = range(2)
+dist_params = dict(backend='gloo')  # gloo is much slower than nccl
 data_workers = 2  # data workers per gpu
-checkpoint_cfg = dict(interval=1)  # save checkpoint at every epoch
+checkpoint_config = dict(interval=1)  # save checkpoint at every epoch
 workflow = [('train', 1), ('val', 1)]
-max_epoch = 6
+total_epochs = 6
 resume_from = None
 load_from = None
 # logging settings
 log_level = 'INFO'
-log_cfg = dict(
-    # log at every 50 iterations
-    interval=50,
+log_config = dict(
+    interval=50,  # log at every 50 iterations
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
......
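For reference, a minimal sketch of how a config like the one above is loaded and read; the training script further below uses Config.fromfile the same way, and the file path here is hypothetical:

    from mmcv import Config

    # hypothetical path to a config file like the one shown above
    cfg = Config.fromfile('configs/cifar10.py')
    print(cfg.optimizer)          # dict(type='SGD', lr=0.1, momentum=0.9, ...)
    print(cfg.lr_config)          # dict(policy='step', step=2), renamed from lr_policy
    print(cfg.checkpoint_config)  # dict(interval=1), renamed from checkpoint_cfg
    print(cfg.total_epochs)       # 6, renamed from max_epoch

The renamed keys (lr_config, optimizer_config, checkpoint_config, log_config, total_epochs) map one-to-one onto the arguments passed to the runner in the updated train_cifar10.py below.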
#!/usr/bin/env bash
PYTHON=${PYTHON:-"python"}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train_cifar10.py $1 --launcher pytorch ${@:3}
\ No newline at end of file
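Assuming the launcher script above is saved as dist_train_cifar10.sh next to train_cifar10.py (its filename is not shown in this view), a 2-GPU run would look like: bash dist_train_cifar10.sh <config file> 2 — $1 is the config file, $2 becomes --nproc_per_node, and any remaining arguments are forwarded to train_cifar10.py.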
+import logging
+import os
 from argparse import ArgumentParser
 from collections import OrderedDict
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn.functional as F
 from mmcv import Config
-from mmcv.torchpack import Runner
+from mmcv.torchpack import Runner, DistSamplerSeedHook
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from torchvision import datasets, transforms
 import resnet_cifar
......
@@ -41,18 +48,55 @@ def batch_processor(model, data, train_mode):
     return outputs
+def get_logger(log_level):
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
+    logger = logging.getLogger()
+    return logger
+def init_dist(backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
 def parse_args():
     parser = ArgumentParser(description='Train CIFAR-10 classification')
     parser.add_argument('config', help='train config file path')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     return parser.parse_args()
 def main():
     args = parse_args()
     cfg = Config.fromfile(args.config)
-    model = getattr(resnet_cifar, cfg.model)()
-    model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()
+    logger = get_logger(cfg.log_level)
+    # init distributed environment if necessary
+    if args.launcher == 'none':
+        dist = False
+        logger.info('Disabled distributed training.')
+    else:
+        dist = True
+        init_dist(**cfg.dist_params)
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        if rank != 0:
+            logger.setLevel('ERROR')
+        logger.info('Enabled distributed training.')
+    # build datasets and dataloaders
     normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
     train_dataset = datasets.CIFAR10(
         root=cfg.data_root,
......
@@ -65,37 +109,67 @@ def main():
         ]))
     val_dataset = datasets.CIFAR10(
         root=cfg.data_root,
         train=False,
         transform=transforms.Compose([
             transforms.ToTensor(),
             normalize,
         ]))
-    num_workers = cfg.data_workers * len(cfg.gpus)
-    train_loader = torch.utils.data.DataLoader(
+    if dist:
+        num_workers = cfg.data_workers
+        assert cfg.batch_size % world_size == 0
+        batch_size = cfg.batch_size // world_size
+        train_sampler = DistributedSampler(train_dataset, world_size, rank)
+        val_sampler = DistributedSampler(val_dataset, world_size, rank)
+        shuffle = False
+    else:
+        num_workers = cfg.data_workers * len(cfg.gpus)
+        batch_size = cfg.batch_size
+        train_sampler = None
+        val_sampler = None
+        shuffle = True
+    train_loader = DataLoader(
         train_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=True,
-        num_workers=num_workers,
-        pin_memory=True)
-    val_loader = torch.utils.data.DataLoader(
+        batch_size=batch_size,
+        shuffle=shuffle,
+        sampler=train_sampler,
+        num_workers=num_workers)
+    val_loader = DataLoader(
         val_dataset,
-        batch_size=cfg.batch_size,
+        batch_size=batch_size,
         shuffle=False,
-        num_workers=num_workers,
-        pin_memory=True)
-    runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
-    runner.register_default_hooks(
-        lr_config=cfg.lr_policy,
-        checkpoint_config=cfg.checkpoint_cfg,
-        log_config=cfg.log_cfg)
+        sampler=val_sampler,
+        num_workers=num_workers)
+    # build model
+    model = getattr(resnet_cifar, cfg.model)()
+    if dist:
+        model = DistributedDataParallel(
+            model.cuda(), device_ids=[torch.cuda.current_device()])
+    else:
+        model = DataParallel(model, device_ids=cfg.gpus).cuda()
+    # build runner and register hooks
+    runner = Runner(
+        model,
+        batch_processor,
+        cfg.optimizer,
+        cfg.work_dir,
+        log_level=cfg.log_level)
+    runner.register_training_hooks(
+        lr_config=cfg.lr_config,
+        optimizer_config=cfg.optimizer_config,
+        checkpoint_config=cfg.checkpoint_config,
+        log_config=cfg.log_config)
+    if dist:
+        runner.register_hook(DistSamplerSeedHook())
+    # load param (if necessary) and run
     if cfg.get('resume_from') is not None:
         runner.resume(cfg.resume_from)
     elif cfg.get('load_from') is not None:
         runner.load_checkpoint(cfg.load_from)
-    runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
+    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)
 if __name__ == '__main__':
......
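The body of batch_processor is truncated in this view; only its final return is shown. Below is a minimal sketch of what such a function typically looks like for this example, assuming the Runner expects a dict with 'loss', 'log_vars' and 'num_samples' keys (the convention of mmcv's Runner; the exact contract of the torchpack-era Runner in this diff may differ slightly). The use of OrderedDict and torch.nn.functional in the file's imports suggests a shape along these lines:

    import torch.nn.functional as F
    from collections import OrderedDict

    def batch_processor(model, data, train_mode):
        # unpack a batch from the CIFAR-10 DataLoader: (images, labels);
        # the DataParallel / DistributedDataParallel wrapper scatters the
        # inputs to the right GPU(s)
        img, label = data
        label = label.cuda(non_blocking=True)
        pred = model(img)
        loss = F.cross_entropy(pred, label)
        acc = (pred.argmax(dim=1) == label).float().mean()
        log_vars = OrderedDict(loss=loss.item(), acc=acc.item())
        # assumed keys: 'loss' is backpropagated by the optimizer hook,
        # 'log_vars' are averaged and printed by the logger hooks,
        # 'num_samples' weights those averages
        return dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))

As a side note on the distributed path, the DistSamplerSeedHook registered above updates the DistributedSampler's epoch before each training epoch, so that the shuffling order differs from epoch to epoch.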