Unverified Commit 32e382bc authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

v0.4.3 (#71)

- ADE20K training model
- Amazon legal approval

fixes https://github.com/zhanghang1989/PyTorch-Encoding/issues/69
parent 9bc70531
...@@ -59,7 +59,7 @@ def test(args): ...@@ -59,7 +59,7 @@ def test(args):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
checkpoint = torch.load(args.resume) checkpoint = torch.load(args.resume)
# strict=False, so that it is compatible with old pytorch saved models # strict=False, so that it is compatible with old pytorch saved models
model.load_state_dict(checkpoint['state_dict'], strict=False) model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
print(model) print(model)
......
...@@ -12,6 +12,7 @@ if __name__ == "__main__": ...@@ -12,6 +12,7 @@ if __name__ == "__main__":
print(model) print(model)
model.cuda() model.cuda()
model.eval()
x = Variable(torch.Tensor(4, 3, 480, 480)).cuda() x = Variable(torch.Tensor(4, 3, 480, 480)).cuda()
with torch.no_grad(): with torch.no_grad():
out = model(x) out = model(x)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
########################################################################### ###########################################################################
import os import os
import copy
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
...@@ -64,7 +65,8 @@ class Trainer(): ...@@ -64,7 +65,8 @@ class Trainer():
if args.ft: if args.ft:
args.start_epoch = 0 args.start_epoch = 0
# criterions # criterions
self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass) self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
nclass=self.nclass)
self.model, self.optimizer = model, optimizer self.model, self.optimizer = model, optimizer
# using cuda # using cuda
if args.cuda: if args.cuda:
...@@ -86,7 +88,8 @@ class Trainer(): ...@@ -86,7 +88,8 @@ class Trainer():
print("=> loaded checkpoint '{}' (epoch {})" print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch'])) .format(args.resume, checkpoint['epoch']))
# lr scheduler # lr scheduler
self.scheduler = utils.LR_Scheduler(args, len(self.trainloader)) self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
args.epochs, len(self.trainloader))
self.best_pred = 0.0 self.best_pred = 0.0
def training(self, epoch): def training(self, epoch):
...@@ -106,19 +109,21 @@ class Trainer(): ...@@ -106,19 +109,21 @@ class Trainer():
train_loss += loss.item() train_loss += loss.item()
tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
# save checkpoint every epoch if self.args.no_val:
is_best = False # save checkpoint every epoch
utils.save_checkpoint({ is_best = False
'epoch': epoch + 1, utils.save_checkpoint({
'state_dict': self.model.module.state_dict(), 'epoch': epoch + 1,
'optimizer': self.optimizer.state_dict(), 'state_dict': self.model.module.state_dict(),
'best_pred': self.best_pred, 'optimizer': self.optimizer.state_dict(),
}, self.args, is_best) 'best_pred': self.best_pred,
}, self.args, is_best)
def validation(self, epoch): def validation(self, epoch):
# Fast test during the training # Fast test during the training
def eval_batch(image, target): def eval_batch(model, image, target):
outputs = self.model(image) outputs = model(image)
outputs = gather(outputs, 0, dim=0) outputs = gather(outputs, 0, dim=0)
pred = outputs[0] pred = outputs[0]
target = target.cuda() target = target.cuda()
...@@ -133,10 +138,10 @@ class Trainer(): ...@@ -133,10 +138,10 @@ class Trainer():
for i, (image, target) in enumerate(tbar): for i, (image, target) in enumerate(tbar):
if torch_ver == "0.3": if torch_ver == "0.3":
image = Variable(image, volatile=True) image = Variable(image, volatile=True)
correct, labeled, inter, union = eval_batch(image, target) correct, labeled, inter, union = eval_batch(self.model, image, target)
else: else:
with torch.no_grad(): with torch.no_grad():
correct, labeled, inter, union = eval_batch(image, target) correct, labeled, inter, union = eval_batch(self.model, image, target)
total_correct += correct total_correct += correct
total_label += labeled total_label += labeled
......
"""Prepare MS COCO datasets"""
import os
import shutil
import argparse
import zipfile
from encoding.utils import download, mkdir
# Default on-disk location for encoding datasets; when the user passes
# --download-dir, the __main__ block below replaces this path with a
# symlink to that directory.
_TARGET_DIR = os.path.expanduser('~/.encoding/data')
def parse_args():
    """Build the CLI parser for the MS COCO setup script and return the parsed options."""
    arg_parser = argparse.ArgumentParser(
        description='Initialize MS COCO dataset.',
        epilog='Example: python mscoco.py --download-dir ~/mscoco',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument(
        '--download-dir', type=str, default=None,
        help='dataset directory on disk')
    return arg_parser.parse_args()
def download_coco(path, overwrite=False):
    """Download the MS COCO zip archives into ``path`` and extract them there.

    Parameters
    ----------
    path : str
        Destination directory; created if it does not already exist.
    overwrite : bool
        When True, re-download archives even if a local copy exists.
    """
    coco_archives = (
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
        ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
         'e7aa0f7515c07e23873a9f71d9095b06bcea3e12'),
        ('http://images.cocodataset.org/zips/test2017.zip',
         '99813c02442f3c112d491ea6f30cecf421d0e6b3'),
    )
    mkdir(path)
    for archive_url, sha1 in coco_archives:
        # download() verifies the SHA-1 checksum of the fetched file
        local_zip = download(archive_url, path=path, overwrite=overwrite,
                             sha1_hash=sha1)
        # unpack each archive alongside the downloaded zip
        with zipfile.ZipFile(local_zip) as archive:
            archive.extractall(path=path)
def install_coco_api():
    """Clone the official COCO API repository, install its Python bindings,
    and remove the cloned source tree afterwards.

    Prints a warning (instead of raising) if ``pycocotools`` is still not
    importable after the install attempt.
    """
    repo_url = "https://github.com/cocodataset/cocoapi"
    os.system("git clone " + repo_url)
    os.system("cd cocoapi/PythonAPI/ && python setup.py install")
    # Best-effort cleanup: if the clone failed, the directory does not
    # exist and a bare rmtree would raise FileNotFoundError, masking the
    # real (install) problem.
    shutil.rmtree('cocoapi', ignore_errors=True)
    try:
        import pycocotools
    except ImportError:
        # Only an import failure indicates the install did not take effect;
        # the previous broad `except Exception` could hide unrelated errors.
        print("Installing COCO API failed, please install it manually %s"%(repo_url))
if __name__ == '__main__':
    args = parse_args()
    mkdir(os.path.expanduser('~/.encoding/data'))
    if args.download_dir is not None:
        # mkdir() above (or a previous run) leaves _TARGET_DIR in place, so
        # it must be removed before the symlink can be created.  The old
        # code called os.remove() on a path for which os.path.isdir() was
        # True — os.remove() cannot delete a directory, so this branch
        # always raised IsADirectoryError.
        if os.path.islink(_TARGET_DIR):
            os.remove(_TARGET_DIR)   # stale symlink from a previous run
        elif os.path.isdir(_TARGET_DIR):
            # only succeeds when empty — refuse to silently wipe other
            # datasets that may already live under the target directory
            os.rmdir(_TARGET_DIR)
        # make symlink so ~/.encoding/data points at the user's directory
        os.symlink(args.download_dir, _TARGET_DIR)
    else:
        download_coco(_TARGET_DIR, overwrite=False)
    install_coco_api()
...@@ -19,8 +19,16 @@ def parse_args(): ...@@ -19,8 +19,16 @@ def parse_args():
def download_ade(path, overwrite=False): def download_ade(path, overwrite=False):
_AUG_DOWNLOAD_URLS = [ _AUG_DOWNLOAD_URLS = [
('http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', 'bf9985e9f2b064752bf6bd654d89f017c76c395a'), ('http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
('https://codalabuser.blob.core.windows.net/public/trainval_merged.json', '169325d9f7e9047537fedca7b04de4dddf10b881')] 'bf9985e9f2b064752bf6bd654d89f017c76c395a'),
('https://codalabuser.blob.core.windows.net/public/trainval_merged.json',
'169325d9f7e9047537fedca7b04de4dddf10b881'),
# You can skip these if the network is slow, the dataset will automatically generate them.
('https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth',
'4bfb49e8c1cefe352df876c9b5434e655c9c1d07'),
('https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth',
'ebedc94247ec616c57b9a2df15091784826a7b0c'),
]
download_dir = os.path.join(path, 'downloads') download_dir = os.path.join(path, 'downloads')
mkdir(download_dir) mkdir(download_dir)
for url, checksum in _AUG_DOWNLOAD_URLS: for url, checksum in _AUG_DOWNLOAD_URLS:
......
...@@ -18,7 +18,7 @@ import setuptools.command.install ...@@ -18,7 +18,7 @@ import setuptools.command.install
cwd = os.path.dirname(os.path.abspath(__file__)) cwd = os.path.dirname(os.path.abspath(__file__))
version = '0.4.2' version = '0.4.3'
try: try:
sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=cwd).decode('ascii').strip() cwd=cwd).decode('ascii').strip()
...@@ -83,6 +83,7 @@ setup( ...@@ -83,6 +83,7 @@ setup(
install_requires=requirements, install_requires=requirements,
packages=find_packages(exclude=["tests", "experiments"]), packages=find_packages(exclude=["tests", "experiments"]),
package_data={ 'encoding': [ package_data={ 'encoding': [
'LICENSE',
'lib/cpu/*.h', 'lib/cpu/*.h',
'lib/cpu/*.cpp', 'lib/cpu/*.cpp',
'lib/gpu/*.h', 'lib/gpu/*.h',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment