##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Util Tools"""
import os
import errno
import shutil
import hashlib
import math

import numpy as np
import torch

__all__ = ['LR_Scheduler', 'save_checkpoint', 'batch_pix_accuracy',
           'batch_intersection_union', 'download', 'mkdir', 'check_sha1']


class LR_Scheduler(object):
    """Learning Rate Scheduler

    Step mode: ``lr = baselr * 0.1 ^ {floor((epoch-1) / lr_step)}``

    Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter/maxiter))``

    Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``

    Args:
        args: :attr:`args.lr_scheduler` lr scheduler mode (`step`, `cos`, `poly`),
          :attr:`args.lr` base learning rate,
          :attr:`args.epochs` number of epochs (used by `cos` / `poly`),
          :attr:`args.lr_step` number of epochs between 10x decays (used by `step`)
        niters: number of iterations per epoch (used by `cos` / `poly`)
    """
    def __init__(self, args, niters=0):
        self.mode = args.lr_scheduler
        print('Using {} LR Scheduler!'.format(self.mode))
        self.lr = args.lr
        if self.mode == 'step':
            self.lr_step = args.lr_step
        else:
            self.niters = niters
            # total number of iterations planned for the whole run
            self.N = args.epochs * niters
        # last epoch for which the learning rate was reported
        self.epoch = -1

    def __call__(self, optimizer, i, epoch, best_pred):
        """Compute the lr for iteration ``i`` of ``epoch`` (1-based) and apply it.

        Args:
            optimizer: object exposing ``param_groups`` (list of dicts with 'lr')
            i: iteration index within the current epoch
            epoch: current epoch, counted from 1
            best_pred: best validation score so far (only printed)

        Raises:
            RuntimeError: if the scheduler mode is not recognised.
        """
        if self.mode == 'cos':
            T = (epoch - 1) * self.niters + i
            lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
        elif self.mode == 'poly':
            T = (epoch - 1) * self.niters + i
            # Clamp progress so running past the planned iterations yields
            # lr = 0 instead of raising a negative base to 0.9 (a complex
            # number in Python 3, which then breaks the '%.4f' print below).
            progress = min(1.0 * T / self.N, 1.0)
            lr = self.lr * pow(1 - progress, 0.9)
        elif self.mode == 'step':
            lr = self.lr * (0.1 ** ((epoch - 1) // self.lr_step))
        else:
            raise RuntimeError('Unknown LR scheduler!')
        if epoch > self.epoch:
            print('\n=>Epoches %i, learning rate = %.4f, previous best = %.4f'
                  % (epoch, lr, best_pred))
            self.epoch = epoch
        self._adjust_learning_rate(optimizer, lr)

    def _adjust_learning_rate(self, optimizer, lr):
        """Set ``lr`` on the first param group and ``10 * lr`` on all others."""
        groups = optimizer.param_groups
        groups[0]['lr'] = lr
        # enlarge the lr at the head (every group after the first)
        for group in groups[1:]:
            group['lr'] = lr * 10


# refer to https://github.com/xternalz/WideResNet-pytorch
def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk

    ``state`` is written with ``torch.save`` to
    ``runs/<args.dataset>/<args.model>/<args.checkname>/<filename>``; when
    ``is_best`` is true the file is also copied to ``model_best.pth.tar`` in
    the same directory.
    """
    directory = "runs/%s/%s/%s/" % (args.dataset, args.model, args.checkname)
    # mkdir tolerates the directory already existing (no exists/makedirs race)
    mkdir(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, directory + 'model_best.pth.tar')


def batch_pix_accuracy(predict, target):
    """Batch Pixel Accuracy

    Args:
        predict: input 4D tensor of class scores; argmax is taken over dim 1
        target: label 3D tensor; pixels with a negative label are ignored

    Returns:
        (pixel_correct, pixel_labeled) counts summed over the batch
    """
    _, predict = torch.max(predict, 1)
    # NOTE: compared in numpy; torch.eq was reported as unreliable here when
    # this was written.
    predict = predict.cpu().numpy()
    target = target.cpu().numpy()
    labeled = target >= 0
    pixel_labeled = np.sum(labeled)
    pixel_correct = np.sum((predict == target) * labeled)
    assert pixel_correct <= pixel_labeled
    return pixel_correct, pixel_labeled


def batch_intersection_union(predict, target, nclass):
    """Batch Intersection of Union

    Args:
        predict: input 4D tensor of class scores; argmax is taken over dim 1
        target: label 3D tensor; pixels with a negative label are ignored
        nclass: number of categories (int)

    Returns:
        (area_inter, area_union): CPU float tensors of length ``nclass`` with
        the per-class intersection and union pixel counts over the batch.
        Class ids outside ``[0, nclass)`` are not counted.
    """
    _, predict = torch.max(predict, 1)
    predict = predict.cpu().numpy().astype(np.int64).ravel()
    target = target.cpu().numpy().astype(np.int64).ravel()

    # Drop unlabeled pixels outright. Zeroing them (as an earlier version
    # did) made them indistinguishable from class 0, and multiplying by the
    # match mask also dumped every mismatched pixel into the class-0 bin.
    labeled = target >= 0
    predict = predict[labeled]
    target = target[labeled]

    def _count(labels):
        # exact per-class pixel count; ids >= nclass are discarded
        return np.bincount(labels, minlength=nclass)[:nclass]

    area_inter = _count(target[predict == target])
    area_pred = _count(predict)
    area_lab = _count(target)
    area_union = area_pred + area_lab - area_inter
    return (torch.from_numpy(area_inter).float(),
            torch.from_numpy(area_union).float())


def get_selabel_vector(target, nclass):
    """Get SE-Loss Label in a batch

    Args:
        target: label 3D tensor (BxHxW)
        nclass: number of categories (int)

    Output:
        2D tensor (BxnClass): 1 where the class occurs in that sample's
        labels (within [0, nclass-1]), 0 otherwise
    """
    batch = target.size(0)
    tvect = torch.zeros(batch, nclass)
    for i in range(batch):
        hist = torch.histc(target[i].data.float(),
                           bins=nclass, min=0, max=nclass-1)
        vect = hist > 0
        tvect[i] = vect
    return tvect


def get_mask_pallete(npimg, dataset='detail'):
    """Get image color pallete for visualizing masks

    Args:
        npimg: numpy array of class indices (presumably HxW - confirm)
        dataset: 'ade20k' and 'cityscapes' select their own colour maps;
            any other value (including the default) uses the VOC colour map

    Returns:
        PIL image built from ``npimg`` (cast to uint8) with the palette set.
        The input array is not modified.
    """
    # Local import: `Image` was previously used without ever being imported.
    from PIL import Image
    # recovery boundary
    if dataset == 'pascal_voc':
        # work on a copy so the caller's array is left untouched
        npimg = npimg.copy()
        npimg[npimg == 21] = 255
    # put colormap
    out_img = Image.fromarray(npimg.astype('uint8'))
    if dataset == 'ade20k':
        out_img.putpalette(adepallete)
    elif dataset == 'cityscapes':
        out_img.putpalette(citypallete)
    else:
        out_img.putpalette(vocpallete)
    return out_img


def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download an given URL

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server does not answer with HTTP 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        # Imported lazily: only needed when actually downloading, so the rest
        # of this module stays importable without these packages.
        import requests
        from tqdm import tqdm

        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        mkdir(dirname)

        print('Downloading %s from %s...' % (fname, url))
        r = requests.get(url, stream=True)
        try:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            total_length = r.headers.get('content-length')
            with open(fname, 'wb') as f:
                if total_length is None:  # no content length header
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                else:
                    total_length = int(total_length)
                    for chunk in tqdm(r.iter_content(chunk_size=1024),
                                      total=int(total_length / 1024. + 0.5),
                                      unit='KB', unit_scale=False, dynamic_ncols=True):
                        f.write(chunk)
        finally:
            # a streamed response holds its connection open until closed
            r.close()

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))

    return fname


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # hash in 1 MiB chunks to keep memory bounded for large files
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash


def mkdir(path):
    """make dir exists okay

    Creates ``path`` (and parents); an already existing directory is not an
    error, any other OSError is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise


# ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
def pixel_accuracy(im_pred, im_lab):
    """Return (pixel_correct, pixel_labeled); label 0 counts as unlabeled."""
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)

    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pixel_labeled = np.sum(im_lab > 0)
    pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
    #pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
    return pixel_correct, pixel_labeled


def intersection_and_union(im_pred, im_lab, num_class):
    """Return per-class (area_inter, area_union) for classes 1..num_class-1.

    Label 0 counts as unlabeled; predictions on those pixels are ignored.
    """
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)
    # Remove classes from unlabeled pixels in gt image.
    im_pred = im_pred * (im_lab > 0)
    # Compute area intersection:
    intersection = im_pred * (im_pred == im_lab)
    area_inter, _ = np.histogram(intersection, bins=num_class-1,
                                 range=(1, num_class - 1))
    # Compute area union:
    area_pred, _ = np.histogram(im_pred, bins=num_class-1,
                                range=(1, num_class - 1))
    area_lab, _ = np.histogram(im_lab, bins=num_class-1,
                               range=(1, num_class - 1))
    area_union = area_pred + area_lab - area_inter
    return area_inter, area_union


def _get_voc_pallete(num_cls):
    """Build a flat [r0, g0, b0, r1, g1, b1, ...] colour map for num_cls labels.

    The bits of each label are spread over the R, G and B channels three at
    a time, filling each channel from its most significant bit downwards.
    """
    n = num_cls
    pallete = [0]*(n*3)
    for j in range(0, n):
        lab = j
        pallete[j*3+0] = 0
        pallete[j*3+1] = 0
        pallete[j*3+2] = 0
        i = 0
        while (lab > 0):
            pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
            pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
            pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
            i = i + 1
            lab >>= 3
    return pallete


# colour map used by get_mask_pallete for every dataset other than
# 'ade20k' and 'cityscapes'
vocpallete = _get_voc_pallete(256)

# colour map used by get_mask_pallete for dataset='ade20k'
adepallete = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,255,214,0,25,194,194,102,255,0,92,0,255]

# colour map used by get_mask_pallete for dataset='cityscapes'
citypallete = [128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160,128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,96,192,96,224,192,0,0,0]