"vscode:/vscode.git/clone" did not exist on "52d43127f79b8d9294f6712cf899d5c8346cf30a"
Unverified commit ce461dae, authored by Hang Zhang, committed by GitHub

V1.0.0 (#156)

* v1.0
parent c2cb2aab
import encoding
import shutil
encoding.models.get_model_file('deepten_minc', root='./')
shutil.move('deepten_minc-2e22611a.pth', 'deepten_minc.pth')
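This snippet fetches the hashed DeepTen-MINC checkpoint into the current directory and strips the hash suffix from the filename. A minimal sketch of consuming the renamed file afterwards; the pretrained=False constructor call and the assumption that the file is a plain state_dict are mine, not taken from the commit:

import torch
import encoding

# Assumption: get_model() builds the architecture without weights and the
# renamed file is a plain state_dict saved with torch.save().
model = encoding.models.get_model('deepten_minc', pretrained=False)
state_dict = torch.load('deepten_minc.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()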
@@ -20,6 +20,8 @@ class Options():
        # model params
        parser.add_argument('--model', type=str, default='densenet',
                            help='network model type (default: densenet)')
+       parser.add_argument('--pretrained', action='store_true',
+                           default=False, help='load pretrained model')
        parser.add_argument('--nclass', type=int, default=10, metavar='N',
                            help='number of classes (default: 10)')
        parser.add_argument('--widen', type=int, default=4, metavar='N',

@@ -36,7 +38,9 @@ class Options():
        parser.add_argument('--epochs', type=int, default=600, metavar='N',
                            help='number of epochs to train (default: 600)')
        parser.add_argument('--start_epoch', type=int, default=1,
-                           metavar='N', help='the epoch number to start (default: 0)')
+                           metavar='N', help='the epoch number to start (default: 1)')
+       parser.add_argument('--workers', type=int, default=16,
+                           metavar='N', help='dataloader threads')
        # lr setting
        parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                            help='learning rate (default: 0.1)')

@@ -47,8 +51,8 @@ class Options():
        # optimizer
        parser.add_argument('--momentum', type=float, default=0.9,
                            metavar='M', help='SGD momentum (default: 0.9)')
-       parser.add_argument('--weight-decay', type=float, default=5e-4,
-                           metavar='M', help='SGD weight decay (default: 5e-4)')
+       parser.add_argument('--weight-decay', type=float, default=1e-4,
+                           metavar='M', help='SGD weight decay (default: 1e-4)')
        # cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true',
                            default=False, help='disables CUDA training')
......
@@ -30,7 +30,7 @@ class Options():
        parser.add_argument('--crop-size', type=int, default=480,
                            help='crop image size')
        parser.add_argument('--train-split', type=str, default='train',
                            help='dataset train split (default: train)')
        # training hyper params
        parser.add_argument('--aux', action='store_true', default=False,
                            help='Auxilary Loss')

@@ -44,10 +44,10 @@ class Options():
                            help='number of epochs to train (default: auto)')
        parser.add_argument('--start_epoch', type=int, default=0,
                            metavar='N', help='start epochs (default:0)')
-       parser.add_argument('--batch-size', type=int, default=None,
+       parser.add_argument('--batch-size', type=int, default=16,
                            metavar='N', help='input batch size for \
                            training (default: auto)')
-       parser.add_argument('--test-batch-size', type=int, default=None,
+       parser.add_argument('--test-batch-size', type=int, default=16,
                            metavar='N', help='input batch size for \
                            testing (default: same as batch size)')
        # optimizer params

@@ -77,6 +77,8 @@ class Options():
        # evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluating mIoU')
+       parser.add_argument('--test-val', action='store_true', default=False,
+                           help='generate masks on val set')
        parser.add_argument('--no-val', action='store_true', default=False,
                            help='skip validation during training')
        # test option

@@ -92,25 +94,21 @@ class Options():
        if args.epochs is None:
            epoches = {
                'coco': 30,
-               'citys': 240,
+               'pascal_aug': 80,
                'pascal_voc': 50,
-               'pascal_aug': 50,
                'pcontext': 80,
-               'ade20k': 120,
+               'ade20k': 180,
+               'citys': 240,
            }
            args.epochs = epoches[args.dataset.lower()]
-       if args.batch_size is None:
-           args.batch_size = 16
-       if args.test_batch_size is None:
-           args.test_batch_size = args.batch_size
        if args.lr is None:
            lrs = {
-               'coco': 0.01,
-               'citys': 0.01,
-               'pascal_voc': 0.0001,
+               'coco': 0.004,
                'pascal_aug': 0.001,
+               'pascal_voc': 0.0001,
                'pcontext': 0.001,
-               'ade20k': 0.01,
+               'ade20k': 0.004,
+               'citys': 0.004,
            }
            args.lr = lrs[args.dataset.lower()] / 16 * args.batch_size
        print(args)
......
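The lr assignment in the hunk above implements linear learning-rate scaling: each per-dataset base rate is defined for an effective batch size of 16 and is rescaled in proportion to the batch size actually used. A short worked example of the arithmetic (the chosen batch size is only illustrative):

# Base rates from the table above assume a total batch size of 16.
base_lr = {'ade20k': 0.004, 'coco': 0.004, 'pcontext': 0.001}
batch_size = 8                                  # illustrative value
lr = base_lr['ade20k'] / 16 * batch_size        # 0.004 / 16 * 8 = 0.002
print(lr)                                       # 0.002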
@@ -14,7 +14,7 @@ import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather
import encoding.utils as utils
-from encoding.nn import SegmentationLosses, BatchNorm2d
+from encoding.nn import SegmentationLosses, SyncBatchNorm
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule

@@ -34,6 +34,9 @@ def test(args):
    if args.eval:
        testset = get_segmentation_dataset(args.dataset, split='val', mode='testval',
                                           transform=input_transform)
+   elif args.test_val:
+       testset = get_segmentation_dataset(args.dataset, split='val', mode='test',
+                                          transform=input_transform)
    else:
        testset = get_segmentation_dataset(args.dataset, split='test', mode='test',
                                           transform=input_transform)

@@ -46,10 +49,12 @@ def test(args):
    # model
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
+       #model.base_size = args.base_size
+       #model.crop_size = args.crop_size
    else:
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone = args.backbone, aux = args.aux,
-                                      se_loss = args.se_loss, norm_layer = BatchNorm2d,
+                                      se_loss = args.se_loss, norm_layer = SyncBatchNorm,
                                       base_size=args.base_size, crop_size=args.crop_size)
    # resuming checkpoint
    if args.resume is None or not os.path.isfile(args.resume):

@@ -60,8 +65,8 @@ def test(args):
    print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    print(model)
-   scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
-       [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+   scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
+       [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)
......
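MultiEvalModule in test.py above runs multi-scale testing over the scale lists shown. The sketch below only illustrates the general idea (resize, forward, resize back, average class scores); it is not the library's implementation, and the callable model returning per-pixel logits is an assumption:

import torch
import torch.nn.functional as F

def multi_scale_predict(model, image, scales, nclass):
    # image: 1 x 3 x H x W; accumulate softmax scores over rescaled copies.
    _, _, h, w = image.shape
    scores = torch.zeros(1, nclass, h, w, device=image.device)
    for s in scales:
        size = (int(h * s), int(w * s))
        x = F.interpolate(image, size=size, mode='bilinear', align_corners=True)
        with torch.no_grad():
            out = model(x)                       # assumed: 1 x nclass logit map at scale s
        out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
        scores += torch.softmax(out, dim=1)
    return scores.argmax(dim=1)                  # 1 x H x W predicted labels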
@@ -7,6 +7,7 @@ from torch.autograd import Variable
if __name__ == "__main__":
    args = Options().parse()
    model = encoding.models.get_segmentation_model(args.model, dataset=args.dataset, aux=args.aux,
+                                                  backbone=args.backbone,
                                                   se_loss=args.se_loss, norm_layer=torch.nn.BatchNorm2d)
    print('Creating the model:')
......
@@ -15,9 +15,9 @@ import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather
import encoding.utils as utils
-from encoding.nn import SegmentationLosses, BatchNorm2d
+from encoding.nn import SegmentationLosses, SyncBatchNorm, OHEMSegmentationLosses
from encoding.parallel import DataParallelModel, DataParallelCriterion
-from encoding.datasets import get_segmentation_dataset
+from encoding.datasets import get_dataset
from encoding.models import get_segmentation_model
from option import Options

@@ -36,9 +36,9 @@ class Trainer():
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
-       trainset = get_segmentation_dataset(args.dataset, split=args.train_split, mode='train',
+       trainset = get_dataset(args.dataset, split=args.train_split, mode='train',
                               **data_kwargs)
-       testset = get_segmentation_dataset(args.dataset, split='val', mode='val',
+       testset = get_dataset(args.dataset, split='val', mode='val',
                              **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \

@@ -51,7 +51,7 @@ class Trainer():
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone = args.backbone, aux = args.aux,
-                                      se_loss = args.se_loss, norm_layer = BatchNorm2d,
+                                      se_loss = args.se_loss, norm_layer = SyncBatchNorm,
                                       base_size=args.base_size, crop_size=args.crop_size)
        print(model)
        # optimizer using different LR

@@ -63,7 +63,8 @@ class Trainer():
        optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                    momentum=args.momentum, weight_decay=args.weight_decay)
        # criterions
-       self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
+       self.criterion = SegmentationLosses(se_loss=args.se_loss,
+                                           aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)

@@ -160,12 +161,12 @@ class Trainer():
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, self.args, is_best)

if __name__ == "__main__":
@@ -174,7 +175,10 @@ if __name__ == "__main__":
    trainer = Trainer(args)
    print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epoches:', trainer.args.epochs)
-   for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
-       trainer.training(epoch)
-       if not trainer.args.no_val:
-           trainer.validation(epoch)
+   if args.eval:
+       trainer.validation(trainer.args.start_epoch)
+   else:
+       for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
+           trainer.training(epoch)
+           if not trainer.args.no_val:
+               trainer.validation(epoch)
"""Prepare ADE20K dataset"""
import os
import shutil
import argparse
import zipfile
from encoding.utils import download, mkdir, check_sha1
_TARGET_DIR = os.path.expanduser('~/.encoding/data')
def parse_args():
parser = argparse.ArgumentParser(
description='Initialize ADE20K dataset.',
epilog='Example: python prepare_cityscapes.py',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
args = parser.parse_args()
return args
def download_city(path, overwrite=False):
_CITY_DOWNLOAD_URLS = [
#('gtCoarse.zip', '61f23198bfff5286e0d7e316ad5c4dbbaaf4717a'),
('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
download_dir = os.path.join(path, 'downloads')
mkdir(download_dir)
for filename, checksum in _CITY_DOWNLOAD_URLS:
if not check_sha1(filename, checksum):
raise UserWarning('File {} is downloaded but the content hash does not match. ' \
'The repo may be outdated or download may be incomplete. ' \
'If the "repo_url" is overridden, consider switching to ' \
'the default repo.'.format(filename))
# extract
with zipfile.ZipFile(filename,"r") as zip_ref:
zip_ref.extractall(path=path)
print("Extracted", filename)
if __name__ == '__main__':
args = parse_args()
mkdir(os.path.expanduser('~/.encoding/data'))
mkdir(os.path.expanduser('~/.encoding/data/cityscapes'))
if args.download_dir is not None:
if os.path.isdir(_TARGET_DIR):
os.remove(_TARGET_DIR)
# make symlink
os.symlink(args.download_dir, _TARGET_DIR)
else:
download_city(_TARGET_DIR, overwrite=False)
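Note that download_city, as written, does not download anything: Cityscapes requires registration, so the two archives are expected to already be present in the working directory; the function only verifies their SHA-1 checksums and extracts them. A small standard-library sketch of what a checksum helper like encoding.utils.check_sha1 presumably does:

import hashlib

def sha1_matches(filename, expected_sha1, chunk_size=1 << 20):
    # Stream the file in 1 MB chunks so large archives need not fit in memory.
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            sha1.update(chunk)
    return sha1.hexdigest() == expected_sha1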
@@ -20,21 +20,28 @@ def download_coco(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
-       ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
-        '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
+       ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
+        '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        #('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
        # '46cdcf715b6b4f67e980b529534e79c2edffe084'),
        #('http://images.cocodataset.org/zips/test2017.zip',
        # '99813c02442f3c112d491ea6f30cecf421d0e6b3'),
+       ('https://hangzh.s3.amazonaws.com/encoding/data/coco/train_ids.pth',
+        '12cd266f97c8d9ea86e15a11f11bcb5faba700b6'),
+       ('https://hangzh.s3.amazonaws.com/encoding/data/coco/val_ids.pth',
+        '4ce037ac33cbf3712fd93280a1c5e92dae3136bb'),
    ]
    mkdir(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
-       with zipfile.ZipFile(filename) as zf:
-           zf.extractall(path=path)
+       if os.path.splitext(filename)[1] == '.zip':
+           with zipfile.ZipFile(filename) as zf:
+               zf.extractall(path=path)
+       else:
+           shutil.move(filename, os.path.join(path, 'annotations/'+os.path.basename(filename)))

def install_coco_api():
......
"""Prepare the ImageNet dataset"""
import os
import argparse
import tarfile
import pickle
import gzip
import subprocess
from tqdm import tqdm
from encoding.utils import check_sha1, download, mkdir
_TARGET_DIR = os.path.expanduser('~/.encoding/datasets/imagenet')
_TRAIN_TAR = 'ILSVRC2012_img_train.tar'
_TRAIN_TAR_SHA1 = '43eda4fe35c1705d6606a6a7a633bc965d194284'
_VAL_TAR = 'ILSVRC2012_img_val.tar'
_VAL_TAR_SHA1 = '5f3f73da3395154b60528b2b2a2caf2374f5f178'
def parse_args():
parser = argparse.ArgumentParser(
description='Setup the ImageNet dataset.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--download-dir', required=True,
help="The directory that contains downloaded tar files")
parser.add_argument('--target-dir', default=_TARGET_DIR,
help="The directory to store extracted images")
parser.add_argument('--checksum', action='store_true',
help="If check integrity before extracting.")
parser.add_argument('--with-rec', action='store_true',
help="If build image record files.")
parser.add_argument('--num-thread', type=int, default=1,
help="Number of threads to use when building image record file.")
args = parser.parse_args()
return args
def check_file(filename, checksum, sha1):
if not os.path.exists(filename):
raise ValueError('File not found: '+filename)
if checksum and not check_sha1(filename, sha1):
raise ValueError('Corrupted file: '+filename)
def build_rec_process(img_dir, train=False, num_thread=1):
rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
mkdir(rec_dir)
prefix = 'train' if train else 'val'
print('Building ImageRecord file for ' + prefix + ' ...')
to_path = rec_dir
# download lst file and im2rec script
script_path = os.path.join(rec_dir, 'im2rec.py')
script_url = 'https://raw.githubusercontent.com/apache/incubator-encoding/master/tools/im2rec.py'
download(script_url, script_path)
lst_path = os.path.join(rec_dir, prefix + '.lst')
lst_url = 'http://data.encoding.io/models/imagenet/resnet/' + prefix + '.lst'
download(lst_url, lst_path)
# execution
import sys
cmd = [
sys.executable,
script_path,
rec_dir,
img_dir,
'--recursive',
'--pass-through',
'--pack-label',
'--num-thread',
str(num_thread)
]
subprocess.call(cmd)
os.remove(script_path)
os.remove(lst_path)
print('ImageRecord file for ' + prefix + ' has been built!')
def extract_train(tar_fname, target_dir, with_rec=False, num_thread=1):
os.makedirs(target_dir)
with tarfile.open(tar_fname) as tar:
print("Extracting "+tar_fname+"...")
# extract each class one-by-one
pbar = tqdm(total=len(tar.getnames()))
for class_tar in tar:
pbar.set_description('Extract '+class_tar.name)
tar.extract(class_tar, target_dir)
class_fname = os.path.join(target_dir, class_tar.name)
class_dir = os.path.splitext(class_fname)[0]
os.mkdir(class_dir)
with tarfile.open(class_fname) as f:
f.extractall(class_dir)
os.remove(class_fname)
pbar.update(1)
pbar.close()
if with_rec:
build_rec_process(target_dir, True, num_thread)
def extract_val(tar_fname, target_dir, with_rec=False, num_thread=1):
os.makedirs(target_dir)
print('Extracting ' + tar_fname)
with tarfile.open(tar_fname) as tar:
tar.extractall(target_dir)
# build rec file before images are moved into subfolders
if with_rec:
build_rec_process(target_dir, False, num_thread)
# move images to proper subfolders
val_maps_file = os.path.join(os.path.dirname(__file__), 'imagenet_val_maps.pklz')
with gzip.open(val_maps_file, 'rb') as f:
dirs, mappings = pickle.load(f)
for d in dirs:
os.makedirs(os.path.join(target_dir, d))
for m in mappings:
os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0]))
def main():
args = parse_args()
target_dir = os.path.expanduser(args.target_dir)
if os.path.exists(target_dir):
raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')
download_dir = os.path.expanduser(args.download_dir)
train_tar_fname = os.path.join(download_dir, _TRAIN_TAR)
check_file(train_tar_fname, args.checksum, _TRAIN_TAR_SHA1)
val_tar_fname = os.path.join(download_dir, _VAL_TAR)
check_file(val_tar_fname, args.checksum, _VAL_TAR_SHA1)
build_rec = args.with_rec
if build_rec:
os.makedirs(os.path.join(target_dir, 'rec'))
extract_train(train_tar_fname, os.path.join(target_dir, 'train'), build_rec, args.num_thread)
extract_val(val_tar_fname, os.path.join(target_dir, 'val'), build_rec, args.num_thread)
if __name__ == '__main__':
main()
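After extraction the script leaves a torchvision-style layout with one folder per WordNet id under train/ and val/. A quick sanity check one might run afterwards; the 1000-class count is the standard ILSVRC2012 split and the path is the script's default target directory:

import os

root = os.path.expanduser('~/.encoding/datasets/imagenet')
for split in ('train', 'val'):
    split_dir = os.path.join(root, split)
    classes = [d for d in os.listdir(split_dir)
               if os.path.isdir(os.path.join(split_dir, d))]
    n_images = sum(len(os.listdir(os.path.join(split_dir, c))) for c in classes)
    print(split, len(classes), 'classes,', n_images, 'images')  # expect 1000 classes each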
@@ -23,7 +23,6 @@ def download_ade(path, overwrite=False):
         'bf9985e9f2b064752bf6bd654d89f017c76c395a'),
        ('https://codalabuser.blob.core.windows.net/public/trainval_merged.json',
         '169325d9f7e9047537fedca7b04de4dddf10b881'),
-       # You can skip these if the network is slow, the dataset will automatically generate them.
        ('https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth',
         '4bfb49e8c1cefe352df876c9b5434e655c9c1d07'),
        ('https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth',
......
@@ -18,7 +18,7 @@ import setuptools.command.install
cwd = os.path.dirname(os.path.abspath(__file__))

-version = '0.5.1'
+version = '1.0.1'
try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                  cwd=cwd).decode('ascii').strip()
......
@@ -173,13 +173,13 @@ def test_encoding_dist_inference():
    test = gradcheck(encoding.functions.encoding_dist_inference, input, eps=EPS, atol=ATOL)
    print('Testing encoding_dist_inference(): {}'.format(test))

-def test_sum_square():
+def test_moments():
    B,C,H = 2,3,4
    X = Variable(torch.cuda.DoubleTensor(B,C,H).uniform_(-0.5,0.5),
                 requires_grad=True)
    input = (X,)
-   test = gradcheck(encoding.functions.sum_square, input, eps=EPS, atol=ATOL)
-   print('Testing sum_square(): {}'.format(test))
+   test = gradcheck(encoding.functions.moments, input, eps=EPS, atol=ATOL)
+   print('Testing moments(): {}'.format(test))

def test_syncbn_func():
    # generate input
......
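The renamed test above gradchecks encoding.functions.moments instead of sum_square. Assuming moments returns the per-channel mean and biased variance over the batch and length dimensions (the statistics batch normalization needs), a plain-PyTorch reference one could compare against might look like this; it is a sketch, not the CUDA extension's implementation:

import torch

def moments_reference(x):
    # x: B x C x N; reduce over batch and positions, keep the channel dim.
    mean = x.mean(dim=(0, 2))
    var = x.var(dim=(0, 2), unbiased=False)
    return mean, var

x = torch.randn(2, 3, 4, dtype=torch.double)
mean, var = moments_reference(x)
print(mean.shape, var.shape)  # torch.Size([3]) torch.Size([3])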
@@ -49,7 +49,7 @@ def testSyncBN():
    def _find_bn(module):
        for m in module.modules():
            if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
-                             encoding.nn.BatchNorm1d, encoding.nn.BatchNorm2d)):
+                             encoding.nn.SyncBatchNorm)):
                return m

    def _syncParameters(bn1, bn2):
        bn1.reset_parameters()

@@ -70,29 +70,128 @@ def testSyncBN():
        input1 = Variable(input.clone().detach(), requires_grad=True)
        input2 = Variable(input.clone().detach(), requires_grad=True)
-       output1 = bn1(input1)
-       output2 = bn2(input2)
+       if is_train:
+           bn1.train()
+           bn2.train()
+           output1 = bn1(input1)
+           output2 = bn2(input2)
+       else:
+           bn1.eval()
+           bn2.eval()
+           with torch.no_grad():
+               output1 = bn1(input1)
+               output2 = bn2(input2)
        # assert forwarding
-       _assert_tensor_close(input1.data, input2.data)
+       #_assert_tensor_close(input1.data, input2.data)
        _assert_tensor_close(output1.data, output2.data)
        if not is_train:
            return
        (output1 ** 2).sum().backward()
        (output2 ** 2).sum().backward()
+       _assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data)
+       _assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data)
        _assert_tensor_close(input1.grad.data, input2.grad.data)
        _assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
-       _assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
+       #_assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

    bn = torch.nn.BatchNorm2d(10).cuda().double()
-   sync_bn = encoding.nn.BatchNorm2d(10).double()
+   sync_bn = encoding.nn.SyncBatchNorm(10, inplace=True, sync=True).cuda().double()
    sync_bn = torch.nn.DataParallel(sync_bn).cuda()
-   encoding.parallel.patch_replication_callback(sync_bn)
    # check with unsync version
+   #_check_batchnorm_result(bn, sync_bn, torch.rand(2, 1, 2, 2).double(), True, cuda=True)
    for i in range(10):
        print(i)
        _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), True, cuda=True)
-       _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)
+       #_check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)

+def testABN():
+    class NormAct(torch.nn.BatchNorm2d):
+        def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation="none",
+                     slope=0.01):
+            super(NormAct, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
+            self.slope = slope
+
+        def forward(self, x):
+            exponential_average_factor = 0.0
+            if self.training and self.track_running_stats:
+                self.num_batches_tracked += 1
+                if self.momentum is None:  # use cumulative moving average
+                    exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+                else:  # use exponential moving average
+                    exponential_average_factor = self.momentum
+            y = torch.nn.functional.batch_norm(
+                x, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training or not self.track_running_stats,
+                exponential_average_factor, self.eps)
+            return torch.nn.functional.leaky_relu_(y, self.slope)
+
+    def _check_batchnorm_result(bn1, bn2, input, is_train, cuda=False):
+        def _find_bn(module):
+            for m in module.modules():
+                if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
+                                  encoding.nn.SyncBatchNorm)):
+                    return m
+
+        def _syncParameters(bn1, bn2):
+            bn1.reset_parameters()
+            bn2.reset_parameters()
+            if bn1.affine and bn2.affine:
+                bn2.weight.data.copy_(bn1.weight.data)
+                bn2.bias.data.copy_(bn1.bias.data)
+                bn2.running_mean.copy_(bn1.running_mean)
+                bn2.running_var.copy_(bn1.running_var)
+
+        bn1.train(mode=is_train)
+        bn2.train(mode=is_train)
+        if cuda:
+            input = input.cuda()
+        # using the same values for gamma and beta
+        _syncParameters(_find_bn(bn1), _find_bn(bn2))
+        input1 = Variable(input.clone().detach(), requires_grad=True)
+        input2 = Variable(input.clone().detach(), requires_grad=True)
+        if is_train:
+            bn1.train()
+            bn2.train()
+            output1 = bn1(input1)
+            output2 = bn2(input2)
+        else:
+            bn1.eval()
+            bn2.eval()
+            with torch.no_grad():
+                output1 = bn1(input1)
+                output2 = bn2(input2)
+        # assert forwarding
+        _assert_tensor_close(output1.data, output2.data)
+        if not is_train:
+            return
+        loss1 = (output1 ** 2).sum()
+        loss2 = (output2 ** 2).sum()
+        loss1.backward()
+        loss2.backward()
+        _assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data)
+        _assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data)
+        _assert_tensor_close(input1.grad.data, input2.grad.data)
+        _assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
+
+    bn = NormAct(10).cuda().double()
+    inp_abn = encoding.nn.SyncBatchNorm(10, sync=False, activation='leaky_relu', inplace=True).cuda().double()
+    inp_abn = torch.nn.DataParallel(inp_abn).cuda()
+    # check with unsync version
+    for i in range(10):
+        print(i)
+        _check_batchnorm_result(bn, inp_abn, torch.rand(16, 10, 16, 16).double(), True, cuda=True)
+        #_check_batchnorm_result(bn, inp_abn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)

+def test_Atten_Module():
+    B, C, H, W = 8, 24, 10, 10
+    X = Variable(torch.cuda.DoubleTensor(B,C,H,W).uniform_(-0.5,0.5),
+                 requires_grad=True)
+    layer1 = encoding.nn.MultiHeadAttention(4, 24, 24, 24).double().cuda()
+    Y = layer1(X)

if __name__ == '__main__':
    import nose
......
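test_Atten_Module above only checks that encoding.nn.MultiHeadAttention(4, 24, 24, 24) runs on a B x C x H x W tensor. As background, the sketch below shows bare scaled dot-product attention over flattened spatial positions in plain PyTorch, without the learned query/key/value projections a real layer has; the reshaping convention and head count are assumptions, not the library's API:

import torch

def self_attention_sketch(x, num_heads=4):
    # x: B x C x H x W; every spatial position is treated as a token of dim C.
    B, C, H, W = x.shape
    d = C // num_heads
    t = x.flatten(2).transpose(1, 2)                        # B x HW x C
    t = t.reshape(B, H * W, num_heads, d).transpose(1, 2)   # B x heads x HW x d
    attn = torch.softmax(t @ t.transpose(-2, -1) / d ** 0.5, dim=-1)
    out = (attn @ t).transpose(1, 2).reshape(B, H * W, C)
    return out.transpose(1, 2).reshape(B, C, H, W)          # back to B x C x H x W

x = torch.randn(8, 24, 10, 10)
print(self_attention_sketch(x).shape)  # torch.Size([8, 24, 10, 10])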