"...text-generation-inference.git" did not exist on "4c693e65245058a4d0ca227ee30b6d8a35d115f1"
Commit e0422994 authored by Kai Chen

update cifar10 example

parent 961c3388
@@ -8,23 +8,24 @@ batch_size = 64
 # optimizer and learning rate
 optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
-lr_policy = dict(policy='step', step=2)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(policy='step', step=2)
 # runtime settings
 work_dir = './demo'
 gpus = range(2)
+dist_params = dict(backend='gloo')  # gloo is much slower than nccl
 data_workers = 2  # data workers per gpu
-checkpoint_cfg = dict(interval=1)  # save checkpoint at every epoch
+checkpoint_config = dict(interval=1)  # save checkpoint at every epoch
 workflow = [('train', 1), ('val', 1)]
-max_epoch = 6
+total_epochs = 6
 resume_from = None
 load_from = None
 # logging settings
 log_level = 'INFO'
-log_cfg = dict(
-    # log at every 50 iterations
-    interval=50,
+log_config = dict(
+    interval=50,  # log at every 50 iterations
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
...
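For reference, the renamed fields above (optimizer_config, lr_config, checkpoint_config, log_config, total_epochs) match the keyword arguments the updated training script passes to the Runner further down. A minimal sketch of loading such a config with mmcv, assuming the file is saved as cifar10_config.py (a placeholder name, not part of this commit):

# minimal sketch; 'cifar10_config.py' is a hypothetical path
from mmcv import Config

cfg = Config.fromfile('cifar10_config.py')
print(cfg.total_epochs)      # 6
print(cfg.optimizer_config)  # {'grad_clip': None}
print(cfg.lr_config)         # {'policy': 'step', 'step': 2}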
#!/usr/bin/env bash
PYTHON=${PYTHON:-"python"}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train_cifar10.py $1 --launcher pytorch ${@:3}
\ No newline at end of file
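The launch script above (presumably added as dist_train.sh next to train_cifar10.py) takes the config path as its first argument and the number of GPUs as its second, forwarding any remaining arguments to train_cifar10.py via ${@:3}. torch.distributed.launch starts one process per GPU, sets the RANK environment variable (read by init_dist below), and passes --local_rank to each process, which is why the training script gains that argument. Assuming a config file at configs/cifar10.py (a hypothetical path), a two-GPU run would be invoked as: bash dist_train.sh configs/cifar10.py 2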
+import logging
+import os
 from argparse import ArgumentParser
 from collections import OrderedDict

 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn.functional as F
 from mmcv import Config
-from mmcv.torchpack import Runner
+from mmcv.torchpack import Runner, DistSamplerSeedHook
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from torchvision import datasets, transforms

 import resnet_cifar
@@ -41,18 +48,55 @@ def batch_processor(model, data, train_mode):
     return outputs


+def get_logger(log_level):
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
+    logger = logging.getLogger()
+    return logger
+
+
+def init_dist(backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
 def parse_args():
     parser = ArgumentParser(description='Train CIFAR-10 classification')
     parser.add_argument('config', help='train config file path')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     return parser.parse_args()


 def main():
     args = parse_args()
     cfg = Config.fromfile(args.config)
-    model = getattr(resnet_cifar, cfg.model)()
-    model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()
+    logger = get_logger(cfg.log_level)
+
+    # init distributed environment if necessary
+    if args.launcher == 'none':
+        dist = False
+        logger.info('Disabled distributed training.')
+    else:
+        dist = True
+        init_dist(**cfg.dist_params)
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        if rank != 0:
+            logger.setLevel('ERROR')
+        logger.info('Enabled distributed training.')
+
+    # build datasets and dataloaders
     normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
     train_dataset = datasets.CIFAR10(
         root=cfg.data_root,
@@ -65,37 +109,67 @@ def main():
         ]))
     val_dataset = datasets.CIFAR10(
         root=cfg.data_root,
+        train=False,
         transform=transforms.Compose([
             transforms.ToTensor(),
             normalize,
         ]))

-    num_workers = cfg.data_workers * len(cfg.gpus)
-    train_loader = torch.utils.data.DataLoader(
+    if dist:
+        num_workers = cfg.data_workers
+        assert cfg.batch_size % world_size == 0
+        batch_size = cfg.batch_size // world_size
+        train_sampler = DistributedSampler(train_dataset, world_size, rank)
+        val_sampler = DistributedSampler(val_dataset, world_size, rank)
+        shuffle = False
+    else:
+        num_workers = cfg.data_workers * len(cfg.gpus)
+        batch_size = cfg.batch_size
+        train_sampler = None
+        val_sampler = None
+        shuffle = True
+    train_loader = DataLoader(
         train_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=True,
-        num_workers=num_workers,
-        pin_memory=True)
-    val_loader = torch.utils.data.DataLoader(
+        batch_size=batch_size,
+        shuffle=shuffle,
+        sampler=train_sampler,
+        num_workers=num_workers)
+    val_loader = DataLoader(
         val_dataset,
-        batch_size=cfg.batch_size,
+        batch_size=batch_size,
         shuffle=False,
-        num_workers=num_workers,
-        pin_memory=True)
+        sampler=val_sampler,
+        num_workers=num_workers)

-    runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
-    runner.register_default_hooks(
-        lr_config=cfg.lr_policy,
-        checkpoint_config=cfg.checkpoint_cfg,
-        log_config=cfg.log_cfg)
+    # build model
+    model = getattr(resnet_cifar, cfg.model)()
+    if dist:
+        model = DistributedDataParallel(
+            model.cuda(), device_ids=[torch.cuda.current_device()])
+    else:
+        model = DataParallel(model, device_ids=cfg.gpus).cuda()
+
+    # build runner and register hooks
+    runner = Runner(
+        model,
+        batch_processor,
+        cfg.optimizer,
+        cfg.work_dir,
+        log_level=cfg.log_level)
+    runner.register_training_hooks(
+        lr_config=cfg.lr_config,
+        optimizer_config=cfg.optimizer_config,
+        checkpoint_config=cfg.checkpoint_config,
+        log_config=cfg.log_config)
+    if dist:
+        runner.register_hook(DistSamplerSeedHook())
+
+    # load param (if necessary) and run
     if cfg.get('resume_from') is not None:
         runner.resume(cfg.resume_from)
     elif cfg.get('load_from') is not None:
         runner.load_checkpoint(cfg.load_from)
-    runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
+    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)


 if __name__ == '__main__':
...
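Two details in the updated script are easy to miss. The per-process batch size becomes cfg.batch_size // world_size, so the effective global batch size stays at 64 regardless of how many GPUs are used. And DistSamplerSeedHook is registered because DistributedSampler reuses the same shuffle order every epoch unless its epoch counter is advanced; the hook's job is to advance it. A rough sketch of that idea (hypothetical code, not mmcv's actual implementation; it assumes the runner exposes data_loader and epoch attributes):

# hypothetical sketch of a sampler-seeding hook; not the real DistSamplerSeedHook
class SamplerSeedHookSketch:
    def before_train_epoch(self, runner):
        # advance the DistributedSampler's epoch so each epoch gets a fresh,
        # rank-consistent shuffle of the training set
        runner.data_loader.sampler.set_epoch(runner.epoch)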