"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0706786e5393e28ecf8b669bdb9d0ee03239b019"
Commit ffdc1d45 authored by Kai Chen

add initial version of torchpack

parent 02ceae83
# model settings
model = 'resnet18'
# dataset settings
data_root = '/mnt/SSD/dataset/cifar10'
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
batch_size = 64
# optimizer and learning rate
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
lr_policy = dict(policy='step', step=2)
# runtime settings
work_dir = './demo'
gpus = range(2)
data_workers = 2 # data workers per gpu
checkpoint_cfg = dict(interval=1) # save checkpoint at every epoch
workflow = [('train', 1), ('val', 1)]
max_epoch = 6
resume_from = None
load_from = None
# logging settings
log_level = 'INFO'
log_cfg = dict(
# log at every 50 iterations
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
])
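# Usage sketch (not part of this commit): the settings above form a plain-Python
# config file. mmcv.Config.fromfile() executes such a file and exposes its
# top-level variables as attributes (see the training script further below).
# The path 'configs/resnet18_cifar10.py' is a hypothetical example name.
#
#   from mmcv import Config
#   cfg = Config.fromfile('configs/resnet18_cifar10.py')
#   print(cfg.model)      # 'resnet18'
#   print(cfg.optimizer)  # dict with type='SGD', lr=0.1, momentum=0.9, ...
#   print(cfg.work_dir)   # './demo'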
# copied from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, self.expansion * planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def resnet18():
return ResNet(BasicBlock, [2, 2, 2, 2])
def resnet34():
return ResNet(BasicBlock, [3, 4, 6, 3])
def resnet50():
return ResNet(Bottleneck, [3, 4, 6, 3])
def resnet101():
return ResNet(Bottleneck, [3, 4, 23, 3])
def resnet152():
return ResNet(Bottleneck, [3, 8, 36, 3])
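# Quick shape check (example only, not part of the original file): the CIFAR
# variants above expect 3x32x32 inputs and return one logit per class.
if __name__ == '__main__':
    import torch
    net = resnet18()
    dummy = torch.randn(2, 3, 32, 32)  # a fake batch of two CIFAR-10 images
    print(net(dummy).shape)            # torch.Size([2, 10])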
from argparse import ArgumentParser
from collections import OrderedDict
import torch
import torch.nn.functional as F
from mmcv import Config
from mmcv.torchpack import Runner
from torchvision import datasets, transforms
import resnet_cifar
def accuracy(output, target, topk=(1, )):
"""Computes the precision@k for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
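# Quick check (illustration only, not part of the original file): with two
# samples and three classes, where the first prediction is correct and the
# second is wrong, top-1 accuracy is 50%:
#   out = torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])
#   tgt = torch.tensor([1, 2])
#   accuracy(out, tgt, topk=(1,))  # -> [tensor([50.])]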
def batch_processor(model, data, train_mode):
img, label = data
label = label.cuda(non_blocking=True)
pred = model(img)
loss = F.cross_entropy(pred, label)
acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5))
log_vars = OrderedDict()
log_vars['loss'] = loss.item()
log_vars['acc_top1'] = acc_top1.item()
log_vars['acc_top5'] = acc_top5.item()
outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
return outputs
def parse_args():
parser = ArgumentParser(description='Train CIFAR-10 classification')
parser.add_argument('config', help='train config file path')
return parser.parse_args()
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
model = getattr(resnet_cifar, cfg.model)()
model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()
normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
train_dataset = datasets.CIFAR10(
root=cfg.data_root,
train=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
    val_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=False,  # evaluate on the CIFAR-10 test split
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
num_workers = cfg.data_workers * len(cfg.gpus)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True)
runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
runner.register_default_hooks(
lr_config=cfg.lr_policy,
checkpoint_config=cfg.checkpoint_cfg,
log_config=cfg.log_cfg)
if cfg.get('resume_from') is not None:
runner.resume(cfg.resume_from)
elif cfg.get('load_from') is not None:
runner.load_checkpoint(cfg.load_from)
runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
if __name__ == '__main__':
main()
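# Launch sketch (assumed file name, not part of the commit): the script above
# takes a config file such as the one at the top of this commit, e.g.
#   python train_cifar10.py configs/resnet18_cifar10.py
# Checkpoints and logs are then written under cfg.work_dir ('./demo' in the
# example config), and the [('train', 1), ('val', 1)] workflow alternates one
# training epoch with one validation epoch until max_epoch is reached.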
from .hooks import *
from .io import *
from .parallel import *
from .runner import *
from .utils import *
from .hook import Hook
from .checkpoint_saver import CheckpointSaverHook
from .closure import ClosureHook
from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerStepperHook
from .iter_timer import IterTimerHook
from .logger import *
from .hook import Hook
from ..utils import master_only
class CheckpointSaverHook(Hook):
def __init__(self,
interval=-1,
save_optimizer=True,
out_dir=None,
**kwargs):
self.interval = interval
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.args = kwargs
@master_only
def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
return
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
from .hook import Hook
class ClosureHook(Hook):
def __init__(self, fn_name, fn):
assert hasattr(self, fn_name)
assert callable(fn)
setattr(self, fn_name, fn)
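# Usage sketch (hypothetical example, not part of the original file): ClosureHook
# binds a plain function to one of the Hook entry points defined in hook.py, e.g.
#   def print_epoch(runner):
#       print('finished epoch', runner.epoch + 1)
#   hook = ClosureHook('after_train_epoch', print_epoch)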
class Hook(object):
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
def before_train_epoch(self, runner):
self.before_epoch(runner)
def before_val_epoch(self, runner):
self.before_epoch(runner)
def after_train_epoch(self, runner):
self.after_epoch(runner)
def after_val_epoch(self, runner):
self.after_epoch(runner)
def before_train_iter(self, runner):
self.before_iter(runner)
def before_val_iter(self, runner):
self.before_iter(runner)
def after_train_iter(self, runner):
self.after_iter(runner)
def after_val_iter(self, runner):
self.after_iter(runner)
def every_n_epochs(self, runner, n):
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n):
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)
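# Example sketch (not part of the original file): a custom hook only overrides
# the callbacks it needs and can use the every_n_* helpers for interval checks.
class PrintEpochHook(Hook):  # hypothetical example hook
    def after_train_epoch(self, runner):
        if self.every_n_epochs(runner, 5):
            print('finished epoch', runner.epoch + 1)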
import time
from .hook import Hook
class IterTimerHook(Hook):
def before_epoch(self, runner):
self.t = time.time()
def before_iter(self, runner):
runner.log_buffer.update({'data_time': time.time() - self.t})
def after_iter(self, runner):
runner.log_buffer.update({'time': time.time() - self.t})
self.t = time.time()
from .base import LoggerHook
from .pavi import PaviClient, PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from abc import ABCMeta, abstractmethod
from ..hook import Hook
class LoggerHook(Hook):
"""Base class for logger hooks."""
__metaclass__ = ABCMeta
def __init__(self, interval=10, ignore_last=True, reset_flag=False):
self.interval = interval
self.ignore_last = ignore_last
self.reset_flag = reset_flag
@abstractmethod
def log(self, runner):
pass
def before_run(self, runner):
for hook in runner.hooks[::-1]:
if isinstance(hook, LoggerHook):
hook.reset_flag = True
break
def before_epoch(self, runner):
runner.log_buffer.clear() # clear logs of last epoch
def after_train_iter(self, runner):
if self.every_n_inner_iters(runner, self.interval):
runner.log_buffer.average(self.interval)
elif self.end_of_epoch(runner) and not self.ignore_last:
# not precise but more stable
runner.log_buffer.average(self.interval)
if runner.log_buffer.ready:
self.log(runner)
if self.reset_flag:
runner.log_buffer.clear_output()
def after_train_epoch(self, runner):
if runner.log_buffer.ready:
self.log(runner)
def after_val_epoch(self, runner):
runner.log_buffer.average()
self.log(runner)
if self.reset_flag:
runner.log_buffer.clear_output()
from __future__ import print_function
import os
import time
from datetime import datetime
from threading import Thread
import requests
from six.moves.queue import Empty, Queue
from .base import LoggerHook
from ...utils import master_only, get_host_info
class PaviClient(object):
def __init__(self, url, username=None, password=None, instance_id=None):
self.url = url
self.username = self._get_env_var(username, 'PAVI_USERNAME')
self.password = self._get_env_var(password, 'PAVI_PASSWORD')
self.instance_id = instance_id
self.log_queue = None
def _get_env_var(self, var, env_var):
if var is not None:
return str(var)
var = os.getenv(env_var)
if not var:
raise ValueError(
'"{}" is neither specified nor defined as env variables'.
format(env_var))
return var
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
if logger:
log_info = logger.info
log_error = logger.error
else:
log_info = log_error = print
log_info('connecting pavi service {}...'.format(self.url))
post_data = dict(
time=str(datetime.now()),
username=self.username,
password=self.password,
instance_id=self.instance_id,
model=model_name,
work_dir=os.path.abspath(work_dir) if work_dir else '',
session_file=info.get('session_file', ''),
session_text=info.get('session_text', ''),
model_text=info.get('model_text', ''),
device=get_host_info())
try:
response = requests.post(self.url, json=post_data, timeout=timeout)
except Exception as ex:
log_error('fail to connect to pavi service: {}'.format(ex))
else:
if response.status_code == 200:
self.instance_id = response.text
log_info('pavi service connected, instance_id: {}'.format(
self.instance_id))
self.log_queue = Queue()
self.log_thread = Thread(target=self.post_worker_fn)
self.log_thread.daemon = True
self.log_thread.start()
return True
else:
log_error('fail to connect to pavi service, status code: '
'{}, err message: {}'.format(response.status_code,
response.reason))
return False
def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
while True:
try:
log = self.log_queue.get(timeout=queue_timeout)
except Empty:
time.sleep(1)
except Exception as ex:
print('fail to get logs from queue: {}'.format(ex))
else:
retry = 0
while retry < max_retry:
try:
response = requests.post(
self.url, json=log, timeout=req_timeout)
except Exception as ex:
retry += 1
print('error when posting logs to pavi: {}'.format(ex))
else:
status_code = response.status_code
if status_code == 200:
break
else:
                            print('unexpected status code: {}, err msg: {}'
                                  .format(status_code, response.reason))
                            retry += 1
                if retry == max_retry:
                    print('fail to send logs of iteration {}'.format(
                        log['iter_num']))
def log(self, phase, iter, outputs):
if self.log_queue is not None:
logs = {
'time': str(datetime.now()),
'instance_id': self.instance_id,
'flow_id': phase,
'iter_num': iter,
'outputs': outputs,
'msg': ''
}
self.log_queue.put(logs)
class PaviLoggerHook(LoggerHook):
def __init__(self,
url,
username=None,
password=None,
instance_id=None,
interval=10,
reset_meter=True,
ignore_last=True):
self.pavi = PaviClient(url, username, password, instance_id)
        super(PaviLoggerHook, self).__init__(interval, ignore_last,
                                             reset_meter)
@master_only
def connect(self,
model_name,
work_dir=None,
info=dict(),
timeout=5,
logger=None):
return self.pavi.connect(model_name, work_dir, info, timeout, logger)
@master_only
def log(self, runner):
log_outs = runner.log_buffer.output.copy()
log_outs.pop('time', None)
log_outs.pop('data_time', None)
self.pavi.log(runner.mode, runner.iter, log_outs)
from .base import LoggerHook
from ...utils import master_only
class TensorboardLoggerHook(LoggerHook):
def __init__(self,
log_dir,
interval=10,
reset_meter=True,
ignore_last=True):
        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
                                                    reset_meter)
self.log_dir = log_dir
@master_only
def before_run(self, runner):
try:
from tensorboardX import SummaryWriter
except ImportError:
raise ImportError('Please install tensorflow and tensorboardX '
'to use TensorboardLoggerHook.')
else:
self.writer = SummaryWriter(self.log_dir)
@master_only
def log(self, runner):
for var in runner.log_buffer.output:
if var in ['time', 'data_time']:
continue
tag = '{}/{}'.format(var, runner.mode)
self.writer.add_scalar(tag, runner.log_buffer.output[var],
runner.iter)
@master_only
def after_run(self, runner):
self.writer.close()
from .base import LoggerHook
class TextLoggerHook(LoggerHook):
def log(self, runner):
if runner.mode == 'train':
lr_str = ', '.join(
['{:.5f}'.format(lr) for lr in runner.current_lr()])
log_str = 'Epoch [{}][{}/{}]\tlr: {}, '.format(
runner.epoch + 1, runner.inner_iter + 1,
len(runner.data_loader), lr_str)
else:
log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
runner.inner_iter + 1)
if 'time' in runner.log_buffer.output:
log_str += (
'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
format(log=runner.log_buffer.output))
log_items = []
for name, val in runner.log_buffer.output.items():
if name in ['time', 'data_time']:
continue
log_items.append('{}: {:.4f}'.format(name, val))
log_str += ', '.join(log_items)
runner.logger.info(log_str)
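# Resulting log line (illustrative values only), produced once every
# log_cfg.interval training iterations:
#   Epoch [1][50/782]  lr: 0.10000, time: 0.045, data_time: 0.012, loss: 1.8342, acc_top1: 35.9375, acc_top5: 87.5000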
from __future__ import division
from .hook import Hook
class LrUpdaterHook(Hook):
def __init__(self,
by_epoch=True,
warmup=None,
warmup_iters=0,
warmup_ratio=0.1,
**kwargs):
# validate the "warmup" argument
if warmup is not None:
if warmup not in ['constant', 'linear', 'exp']:
raise ValueError(
'"{}" is not a supported type for warming up, valid types'
' are "constant" and "linear"'.format(warmup))
if warmup is not None:
assert warmup_iters > 0, \
'"warmup_iters" must be a positive integer'
assert 0 < warmup_ratio <= 1.0, \
'"warmup_ratio" must be in range (0,1]'
self.by_epoch = by_epoch
self.warmup = warmup
self.warmup_iters = warmup_iters
self.warmup_ratio = warmup_ratio
self.base_lr = [] # initial lr for all param groups
self.regular_lr = [] # expected lr if no warming up is performed
def _set_lr(self, runner, lr_groups):
for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def get_lr(self, runner, base_lr):
raise NotImplementedError
def get_regular_lr(self, runner):
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
def get_warmup_lr(self, cur_iters):
if self.warmup == 'constant':
warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
elif self.warmup == 'linear':
k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
elif self.warmup == 'exp':
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
warmup_lr = [_lr * k for _lr in self.regular_lr]
return warmup_lr
def before_run(self, runner):
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
# it will be set according to the optimizer params
for group in runner.optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lr = [
group['initial_lr'] for group in runner.optimizer.param_groups
]
def before_train_epoch(self, runner):
if not self.by_epoch:
return
self.regular_lr = self.get_regular_lr(runner)
self._set_lr(runner, self.regular_lr)
def before_train_iter(self, runner):
cur_iter = runner.iter
if not self.by_epoch:
self.regular_lr = self.get_regular_lr(runner)
if self.warmup is None or cur_iter >= self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)
elif self.by_epoch:
if self.warmup is None or cur_iter > self.warmup_iters:
return
elif cur_iter == self.warmup_iters:
self._set_lr(runner, self.regular_lr)
else:
warmup_lr = self.get_warmup_lr(cur_iter)
self._set_lr(runner, warmup_lr)
class FixedLrUpdaterHook(LrUpdaterHook):
def __init__(self, **kwargs):
super(FixedLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
return base_lr
class StepLrUpdaterHook(LrUpdaterHook):
def __init__(self, step, gamma=0.1, **kwargs):
assert isinstance(step, (list, int))
if isinstance(step, list):
for s in step:
assert isinstance(s, int) and s > 0
elif isinstance(step, int):
assert step > 0
else:
raise TypeError('"step" must be a list or integer')
self.step = step
self.gamma = gamma
super(StepLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
if isinstance(self.step, int):
return base_lr * (self.gamma**(progress // self.step))
exp = len(self.step)
for i, s in enumerate(self.step):
if progress < s:
exp = i
break
return base_lr * self.gamma**exp
class ExpLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, **kwargs):
self.gamma = gamma
super(ExpLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
return base_lr * self.gamma**progress
class PolyLrUpdaterHook(LrUpdaterHook):
def __init__(self, power=1., **kwargs):
self.power = power
super(PolyLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
else:
progress = runner.iter
max_progress = runner.max_iters
return base_lr * (1 - progress / max_progress)**self.power
class InvLrUpdaterHook(LrUpdaterHook):
def __init__(self, gamma, power=1., **kwargs):
self.gamma = gamma
self.power = power
super(InvLrUpdaterHook, self).__init__(**kwargs)
def get_lr(self, runner, base_lr):
progress = runner.epoch if self.by_epoch else runner.iter
return base_lr * (1 + self.gamma * progress)**(-self.power)
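# Schedule sketch (illustration only, assuming the runner maps policy='step' in
# the config to StepLrUpdaterHook): with step=2, gamma=0.1 and a base lr of 0.1,
# the lr is 0.1 for epochs 0-1, 0.01 for epochs 2-3 and 0.001 for epochs 4-5.
if __name__ == '__main__':
    from types import SimpleNamespace
    hook = StepLrUpdaterHook(step=2, gamma=0.1)
    for epoch in range(6):
        fake_runner = SimpleNamespace(epoch=epoch, iter=0)  # dummy stand-in
        print(epoch, hook.get_lr(fake_runner, 0.1))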
from torch.nn.utils import clip_grad
from .hook import Hook
class OptimizerStepperHook(Hook):
def __init__(self, grad_clip=False, max_norm=35, norm_type=2):
self.grad_clip = grad_clip
self.max_norm = max_norm
self.norm_type = norm_type
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
if self.grad_clip:
clip_grad.clip_grad_norm_(
filter(lambda p: p.requires_grad, runner.model.parameters()),
max_norm=self.max_norm,
norm_type=self.norm_type)
runner.optimizer.step()
import os.path as osp
import time
from collections import OrderedDict
import mmcv
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils import model_zoo
def load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
own_state = module.state_dict()
for name, param in state_dict.items():
if name not in own_state:
unexpected_keys.append(name)
continue
if isinstance(param, torch.nn.Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].copy_(param)
except Exception:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(),
param.size()))
missing_keys = set(own_state.keys()) - set(state_dict.keys())
err_msg = []
if unexpected_keys:
err_msg.append('unexpected key in source state_dict: {}\n'.format(
', '.join(unexpected_keys)))
if missing_keys:
err_msg.append('missing keys in source state_dict: {}\n'.format(
', '.join(missing_keys)))
err_msg = '\n'.join(err_msg)
if err_msg:
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warn(err_msg)
else:
print(err_msg)
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
# load checkpoint from modelzoo or file or url
if filename.startswith('modelzoo://'):
from torchvision.models.resnet import model_urls
model_name = filename[11:]
checkpoint = model_zoo.load_url(model_urls[model_name])
elif filename.startswith(('http://', 'https://')):
checkpoint = model_zoo.load_url(filename)
else:
if not osp.isfile(filename):
raise IOError('{} is not a checkpoint file'.format(filename))
checkpoint = torch.load(filename, map_location=map_location)
# get state_dict from checkpoint
if isinstance(checkpoint, OrderedDict):
state_dict = checkpoint
elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
raise RuntimeError(
'No state_dict found in checkpoint file {}'.format(filename))
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
# load state_dict
if isinstance(model, (DataParallel, DistributedDataParallel)):
load_state_dict(model.module, state_dict, strict, logger)
else:
load_state_dict(model, state_dict, strict, logger)
return checkpoint
def weights_to_cpu(state_dict):
"""Copy a model state_dict to cpu.
Args:
state_dict (OrderedDict): Model weights on GPU.
Returns:
        OrderedDict: Model weights on CPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
return state_dict_cpu
def save_checkpoint(model, filename, optimizer=None, meta=None):
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError('meta must be a dict or None, but got {}'.format(
type(meta)))
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
mmcv.mkdir_or_exist(osp.dirname(filename))
if isinstance(model, (DataParallel, DistributedDataParallel)):
model = model.module
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(model.state_dict())
}
if optimizer is not None:
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, filename)
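# Round-trip sketch (hypothetical path, not part of the original file): save a
# small model and load it back; the checkpoint dict carries 'meta' and
# 'state_dict' (plus 'optimizer' when one is passed).
if __name__ == '__main__':
    import torch.nn as nn
    net = nn.Linear(4, 2)
    save_checkpoint(net, '/tmp/demo_ckpt.pth', meta=dict(epoch=1))
    ckpt = load_checkpoint(net, '/tmp/demo_ckpt.pth', map_location='cpu')
    print(sorted(ckpt.keys()))  # ['meta', 'state_dict']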
import multiprocessing
import torch
from .io import load_checkpoint
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
gpu_id, idx_queue, result_queue):
model = model_cls(**model_kwargs)
load_checkpoint(model, checkpoint, map_location='cpu')
torch.cuda.set_device(gpu_id)
model.cuda()
model.eval()
with torch.no_grad():
while True:
idx = idx_queue.get()
data = dataset[idx]
result = model(**data_func(data, gpu_id))
result_queue.put((idx, result))
def parallel_test(model_cls,
model_kwargs,
checkpoint,
dataset,
data_func,
gpus,
workers_per_gpu=1):
"""Parallel testing on multiple GPUs.
Args:
model_cls (type): Model class type.
model_kwargs (dict): Arguments to init the model.
checkpoint (str): Checkpoint filepath.
dataset (:obj:`Dataset`): The dataset to be tested.
data_func (callable): The function that generates model inputs.
gpus (list[int]): GPU ids to be used.
workers_per_gpu (int): Number of processes on each GPU. It is possible
to run multiple workers on each GPU.
Returns:
list: Test results.
"""
ctx = multiprocessing.get_context('spawn')
idx_queue = ctx.Queue()
result_queue = ctx.Queue()
num_workers = len(gpus) * workers_per_gpu
workers = [
ctx.Process(
target=worker_func,
args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
gpus[i % len(gpus)], idx_queue, result_queue))
for i in range(num_workers)
]
for w in workers:
w.daemon = True
w.start()
for i in range(len(dataset)):
idx_queue.put(i)
results = [None for _ in range(len(dataset))]
import cvbase as cvb
prog_bar = cvb.ProgressBar(task_num=len(dataset))
for _ in range(len(dataset)):
idx, res = result_queue.get()
results[idx] = res
prog_bar.update()
print('\n')
for worker in workers:
worker.terminate()
return results
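# Interface note (inferred from the code above, not documented in the original):
# data_func(data, gpu_id) should return a dict of keyword arguments for the
# model's forward pass, with tensors placed on the given GPU; results are
# collected through the queue and returned in dataset order.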
from .log_buffer import LogBuffer
from .runner import Runner
from collections import OrderedDict
import numpy as np
class LogBuffer(object):
def __init__(self):
self.val_history = OrderedDict()
self.n_history = OrderedDict()
self.output = OrderedDict()
self.ready = False
def clear(self):
self.val_history.clear()
self.n_history.clear()
self.clear_output()
def clear_output(self):
self.output.clear()
self.ready = False
def update(self, vars, count=1):
assert isinstance(vars, dict)
for key, var in vars.items():
if key not in self.val_history:
self.val_history[key] = []
self.n_history[key] = []
self.val_history[key].append(var)
self.n_history[key].append(count)
def average(self, n=0):
"""Average latest n values or all values"""
assert n >= 0
for key in self.val_history:
values = np.array(self.val_history[key][-n:])
nums = np.array(self.n_history[key][-n:])
avg = np.sum(values * nums) / np.sum(nums)
self.output[key] = avg
self.ready = True
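# Usage sketch (not part of the original file): hooks push per-iteration values
# with update() and logger hooks call average() to fill .output.
if __name__ == '__main__':
    buf = LogBuffer()
    buf.update({'loss': 1.0}, count=2)  # e.g. a batch of two samples
    buf.update({'loss': 0.5}, count=2)
    buf.average()                       # weighted average over all recorded values
    print(buf.output['loss'])           # 0.75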