Commit 50dd7d3e authored by dengjb's avatar dengjb
Browse files

update

parents
Pipeline #3040 canceled with stages
from .utils import *
from .visualizer import Visualizer
from .scheduler import PolyLR
from .loss import FocalLoss
\ No newline at end of file
This diff is collapsed.
import torch.nn as nn
import torch.nn.functional as F
import torch
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.ignore_index = ignore_index
self.size_average = size_average
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(
inputs, targets, reduction='none', ignore_index=self.ignore_index)
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()
\ No newline at end of file
from torch.optim.lr_scheduler import _LRScheduler, StepLR
class PolyLR(_LRScheduler):
def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6):
self.power = power
self.max_iters = max_iters # avoid zero lr
self.min_lr = min_lr
super(PolyLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
return [ max( base_lr * ( 1 - self.last_epoch/self.max_iters )**self.power, self.min_lr)
for base_lr in self.base_lrs]
\ No newline at end of file
from torchvision.transforms.functional import normalize
import torch.nn as nn
import numpy as np
import os
def denormalize(tensor, mean, std):
mean = np.array(mean)
std = np.array(std)
_mean = -mean/std
_std = 1/std
return normalize(tensor, _mean, _std)
class Denormalize(object):
def __init__(self, mean, std):
mean = np.array(mean)
std = np.array(std)
self._mean = -mean/std
self._std = 1/std
def __call__(self, tensor):
if isinstance(tensor, np.ndarray):
return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1)
return normalize(tensor, self._mean, self._std)
def set_bn_momentum(model, momentum=0.1):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.momentum = momentum
def fix_bn(model):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def mkdir(path):
if not os.path.exists(path):
os.mkdir(path)
from visdom import Visdom
import json
class Visualizer(object):
""" Visualizer
"""
def __init__(self, port='13579', env='main', id=None):
#self.cur_win = {}
self.vis = Visdom(port=port, env=env)
self.id = id
self.env = env
# Restore
#ori_win = self.vis.get_window_data()
#ori_win = json.loads(ori_win)
#print(ori_win)
#self.cur_win = { v['title']: k for k, v in ori_win.items() }
def vis_scalar(self, name, x, y, opts=None):
if not isinstance(x, list):
x = [x]
if not isinstance(y, list):
y = [y]
if self.id is not None:
name = "[%s]"%self.id + name
default_opts = { 'title': name }
if opts is not None:
default_opts.update(opts)
#win = self.cur_win.get(name, None)
#if win is not None:
self.vis.line( X=x, Y=y, win=name, opts=default_opts, update='append')
#else:
# self.cur_win[name] = self.vis.line( X=x, Y=y, opts=default_opts)
def vis_image(self, name, img, env=None, opts=None):
""" vis image in visdom
"""
if env is None:
env = self.env
if self.id is not None:
name = "[%s]"%self.id + name
#win = self.cur_win.get(name, None)
default_opts = { 'title': name }
if opts is not None:
default_opts.update(opts)
#if win is not None:
self.vis.image( img=img, win=name, opts=opts, env=env )
#else:
# self.cur_win[name] = self.vis.image( img=img, opts=default_opts, env=env )
def vis_table(self, name, tbl, opts=None):
#win = self.cur_win.get(name, None)
tbl_str = "<table width=\"100%\"> "
tbl_str+="<tr> \
<th>Term</th> \
<th>Value</th> \
</tr>"
for k, v in tbl.items():
tbl_str+= "<tr> \
<td>%s</td> \
<td>%s</td> \
</tr>"%(k, v)
tbl_str+="</table>"
default_opts = { 'title': name }
if opts is not None:
default_opts.update(opts)
#if win is not None:
self.vis.text(tbl_str, win=name, opts=default_opts)
#else:
#self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts)
if __name__=='__main__':
import numpy as np
vis = Visualizer(port=35588, env='main')
tbl = {"lr": 214, "momentum": 0.9}
vis.vis_table("test_table", tbl)
tbl = {"lr": 244444, "momentum": 0.9, "haha": "hoho"}
vis.vis_table("test_table", tbl)
vis.vis_scalar(name='loss', x=0, y=1)
vis.vis_scalar(name='loss', x=2, y=4)
vis.vis_scalar(name='loss', x=4, y=6)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment