"tutorials/blitz/3_message_passing.py" did not exist on "16169f3a9d82566fbc9e32a30aa2bca44853d069"
Unverified Commit b872eb8c authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

ResNeSt plus (#256)

parent 5a1e3fbc
...@@ -12,4 +12,8 @@ ...@@ -12,4 +12,8 @@
from .encoding import * from .encoding import *
from .syncbn import * from .syncbn import *
from .customize import * from .customize import *
from .attention import *
from .loss import * from .loss import *
from .rectify import *
from .splat import SplAtConv2d
from .dropblock import *
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2018
###########################################################################
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .syncbn import SyncBatchNorm
__all__ = ['ACFModule', 'MixtureOfSoftMaxACF']
class ACFModule(nn.Module):
    """Multi-Head Attention module (Attentional Context Fusion).

    Projects the input feature map into query/key/value spaces with 1x1
    convolutions (or a small conv FFN), applies mixture-of-softmax attention
    (``MixtureOfSoftMaxACF``), and fuses the attended features back into the
    input via a residual sum or channel concatenation.

    Args:
        n_head (int): number of attention heads.
        n_mix (int): number of softmax mixture components per head.
        d_model (int): number of input/output channels.
        d_k (int): key/query channels per head.
        d_v (int): value channels per head.
        norm_layer: normalization layer class applied after the output projection.
        kq_transform (str): 'conv', 'ffn' or 'dffn' — key/query projection type.
        value_transform (str): only 'conv' is supported.
        pooling (bool): if True, compute keys/values on a 2x-downsampled map
            to reduce the attention cost.
        concat (bool): if True, concatenate attended features with the residual
            instead of adding them (doubles the output channels).
        dropout (float): unused here; attention dropout is configured inside
            the attention module.
    """
    def __init__(self, n_head, n_mix, d_model, d_k, d_v, norm_layer=SyncBatchNorm,
                 kq_transform='conv', value_transform='conv',
                 pooling=True, concat=False, dropout=0.1):
        super(ACFModule, self).__init__()

        self.n_head = n_head
        self.n_mix = n_mix
        self.d_k = d_k
        self.d_v = d_v

        self.pooling = pooling
        self.concat = concat

        if self.pooling:
            # 2x spatial downsampling for keys/values.
            self.pool = nn.AvgPool2d(3, 2, 1, count_include_pad=False)
        if kq_transform == 'conv':
            self.conv_qs = nn.Conv2d(d_model, n_head*d_k, 1)
            nn.init.normal_(self.conv_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        elif kq_transform == 'ffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head*d_k, 3, padding=1, bias=False),
                norm_layer(n_head*d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head*d_k, n_head*d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
        elif kq_transform == 'dffn':
            # Dilated variant of the FFN projection (larger receptive field).
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head*d_k, 3, padding=4, dilation=4, bias=False),
                norm_layer(n_head*d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head*d_k, n_head*d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
        else:
            # Fix: `raise NotImplemented` raises TypeError in Python 3 because
            # NotImplemented is not an exception; raise the proper type.
            raise NotImplementedError(
                "unsupported kq_transform: {!r}".format(kq_transform))
        # Queries and keys share the same projection weights.
        self.conv_ks = self.conv_qs
        if value_transform == 'conv':
            self.conv_vs = nn.Conv2d(d_model, n_head*d_v, 1)
        else:
            raise NotImplementedError(
                "unsupported value_transform: {!r}".format(value_transform))

        nn.init.normal_(self.conv_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = MixtureOfSoftMaxACF(n_mix=n_mix, d_k=d_k)

        self.conv = nn.Conv2d(n_head*d_v, d_model, 1, bias=False)
        self.norm_layer = norm_layer(d_model)

    def forward(self, x):
        """Apply attention and fuse with the residual.

        Input/output: ``(B, d_model, H, W)`` (output has ``2*d_model``
        channels when ``concat=True``).
        """
        residual = x
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        b_, c_, h_, w_ = x.size()

        if self.pooling:
            # Queries at full resolution; keys/values on the pooled map (N/4).
            qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
            vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
        else:
            kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            qt = kt
            vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)

        output, attn = self.attention(qt, kt, vt)

        output = output.transpose(1, 2).contiguous().view(b_, n_head*d_v, h_, w_)
        output = self.conv(output)

        if self.concat:
            output = torch.cat((self.norm_layer(output), residual), 1)
        else:
            output = self.norm_layer(output) + residual
        return output

    def demo(self, x):
        """Return only the attention map (for visualization/debugging)."""
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        b_, c_, h_, w_ = x.size()

        if self.pooling:
            qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
            vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
        else:
            kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            qt = kt
            vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)

        _, attn = self.attention(qt, kt, vt)
        # Fix: the original discarded the result of `view` (tensors are not
        # reshaped in place), returning the un-reshaped attention map.
        attn = attn.view(b_, n_head, h_*w_, -1)
        return attn

    def extra_repr(self):
        return 'n_head={}, n_mix={}, d_k={}, pooling={}' \
            .format(self.n_head, self.n_mix, self.d_k, self.pooling)
class MixtureOfSoftMaxACF(nn.Module):
    """Mixture-of-softmax attention kernel.

    Computes scaled-dot-product attention whose weights are a convex
    combination over ``n_mix`` mixture components; with ``n_mix == 1`` it
    reduces to plain scaled-dot-product attention.
    """

    def __init__(self, n_mix, d_k, attn_dropout=0.1):
        super(MixtureOfSoftMaxACF, self).__init__()
        self.temperature = d_k ** 0.5  # sqrt(d_k) scaling for dot products
        self.n_mix = n_mix
        self.att_drop = attn_dropout
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax1 = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)
        self.d_k = d_k
        if n_mix > 1:
            # One d_k-dimensional mixing vector per component.
            self.weight = nn.Parameter(torch.Tensor(n_mix, d_k))
            bound = n_mix ** -0.5
            self.weight.data.uniform_(-bound, bound)

    def forward(self, qt, kt, vt):
        """Attend: qt (B, d_k, N), kt (B, d_k, N2), vt (B, d_v, N2).

        Returns (output (B, N, d_v), attn (B*m or B, N, N2)).
        """
        B, d_k, N = qt.size()
        assert d_k == self.d_k
        m = self.n_mix
        d = d_k // m  # per-component feature size

        if m > 1:
            # Mixture priors pi from the mean query: (B, m, 1) -> (B*m, 1, 1).
            bar_qt = torch.mean(qt, 2, True)
            pi = self.softmax1(torch.matmul(self.weight, bar_qt)).view(B * m, 1, 1)

        # Split heads into mixture components.
        q = qt.view(B * m, d, N).transpose(1, 2)
        N2 = kt.size(2)
        kt = kt.view(B * m, d, N2)
        v = vt.transpose(1, 2)

        # Scaled dot-product attention, (B*m, N, N2).
        attn = self.dropout(self.softmax2(torch.bmm(q, kt) / self.temperature))
        if m > 1:
            # Collapse the mixture dimension with the priors: -> (B, N, N2).
            attn = (attn * pi).view(B, m, N, N2).sum(1)

        output = torch.bmm(attn, v)
        return output, attn
...@@ -28,8 +28,6 @@ class GlobalAvgPool2d(nn.Module): ...@@ -28,8 +28,6 @@ class GlobalAvgPool2d(nn.Module):
def forward(self, inputs): def forward(self, inputs):
return F.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1) return F.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1)
class GramMatrix(nn.Module): class GramMatrix(nn.Module):
r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
......
# https://github.com/Randl/MobileNetV3-pytorch/blob/master/dropblock.py
import torch
import torch.nn.functional as F
from torch import nn
__all__ = ['DropBlock2D', 'reset_dropblock']
class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of feature map allows to remove semantic
    information as compared to regular dropout.

    The drop probability can be linearly scheduled over training steps via
    :meth:`reset_steps`; :meth:`step` advances the schedule once per forward
    call in training mode.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop
        share_channel (bool): if True, one spatial mask is shared by all
            channels of a sample instead of an independent mask per channel.

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890
    """
    def __init__(self, drop_prob, block_size, share_channel=False):
        super(DropBlock2D, self).__init__()
        # 'i' counts forward passes (schedule position); 'drop_prob' is a
        # buffer so it follows the module across devices.
        self.register_buffer('i', torch.zeros(1, dtype=torch.int64))
        self.register_buffer('drop_prob', drop_prob * torch.ones(1, dtype=torch.float32))
        self.inited = False
        self.step_size = 0.0
        self.start_step = 0
        self.nr_steps = 0
        self.block_size = block_size
        self.share_channel = share_channel

    def reset(self):
        """stop DropBlock"""
        self.inited = True
        self.i[0] = 0
        # Fix: assign in place so `drop_prob` stays a tensor buffer; the
        # original rebound it to a python float, which broke the later
        # `self.drop_prob.item()` calls in forward()/reset_steps().
        self.drop_prob[0] = 0.0

    def reset_steps(self, start_step, nr_steps, start_value=0, stop_value=None):
        """Configure a linear drop_prob ramp from start_value to stop_value."""
        self.inited = True
        # Default target is the currently configured drop probability.
        stop_value = self.drop_prob.item() if stop_value is None else stop_value
        self.i[0] = 0
        self.drop_prob[0] = start_value
        self.step_size = (stop_value - start_value) / nr_steps
        self.nr_steps = nr_steps
        self.start_step = start_step

    def forward(self, x):
        # Identity in eval mode or while dropping is disabled.
        if not self.training or self.drop_prob.item() == 0.:
            return x
        self.step()  # advance the drop_prob schedule
        # Seed-point probability derived from drop_prob and block geometry.
        gamma = self._compute_gamma(x)
        # Sample the seed mask on the input device. Fixes vs. original:
        # keep a (N, 1, H, W) channel dim in the shared case so pooling and
        # broadcasting against (N, C, H, W) line up, and cast the boolean
        # comparison result to x's dtype (F.max_pool2d rejects bool input).
        if self.share_channel:
            mask = (torch.rand(x.shape[0], 1, *x.shape[2:], device=x.device) < gamma).to(x.dtype)
        else:
            mask = (torch.rand(*x.shape, device=x.device) < gamma).to(x.dtype)
        # Expand seeds to blocks and zero them out.
        block_mask, keeped = self._compute_block_mask(mask)
        out = x * block_mask
        # Rescale so the expected activation magnitude is preserved.
        out = out * (block_mask.numel() / keeped).to(out)
        return out

    def _compute_block_mask(self, mask):
        """Expand each seed in `mask` to a block_size x block_size dropped block.

        Returns the keep-mask (1 = keep, 0 = drop) and the kept-element count.
        NOTE(review): for even block_size the pooled output is one element
        larger than the input; callers are expected to use odd block sizes.
        """
        block_mask = F.max_pool2d(mask,
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)
        keeped = block_mask.numel() - block_mask.sum().to(torch.float32)
        block_mask = 1 - block_mask
        return block_mask, keeped

    def _compute_gamma(self, x):
        """Seed probability gamma per the DropBlock paper (eq. 1)."""
        _, c, h, w = x.size()
        gamma = self.drop_prob.item() / (self.block_size ** 2) * (h * w) / \
            ((w - self.block_size + 1) * (h - self.block_size + 1))
        return gamma

    def step(self):
        """Advance the schedule: ramp drop_prob while inside the ramp window."""
        assert self.inited
        idx = self.i.item()
        if idx > self.start_step and idx < self.start_step + self.nr_steps:
            self.drop_prob += self.step_size
        self.i += 1

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Buffers are intentionally excluded from saved checkpoints (see
        # _save_to_state_dict), so supply defaults when loading.
        idx_key = prefix + 'i'
        drop_prob_key = prefix + 'drop_prob'
        if idx_key not in state_dict:
            state_dict[idx_key] = torch.zeros(1, dtype=torch.int64)
        # Fix: the original tested `idx_key not in drop_prob_key` — a
        # substring test on the key string, not membership in state_dict —
        # and would have clobbered drop_prob with ones; keep the current
        # configured value instead.
        if drop_prob_key not in state_dict:
            state_dict[drop_prob_key] = self.drop_prob.detach().clone()
        super(DropBlock2D, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """overwrite save method: schedule state is deliberately not saved."""
        pass

    def extra_repr(self):
        return 'drop_prob={}, step_size={}'.format(self.drop_prob, self.step_size)
def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
    """Re-arm the drop-probability schedule of a DropBlock2D layer.

    No-op for any other module type, so it composes with ``nn.Module.apply``:

        from functools import partial
        apply_drop_prob = partial(reset_dropblock, 0, epochs*iters_per_epoch, 0.0, 0.1)
        net.apply(apply_drop_prob)
    """
    if not isinstance(m, DropBlock2D):
        return
    print('reseting dropblock')
    m.reset_steps(start_step, nr_steps, start_value, stop_value)
...@@ -17,7 +17,8 @@ from torch.nn.modules.utils import _pair ...@@ -17,7 +17,8 @@ from torch.nn.modules.utils import _pair
from ..functions import scaled_l2, aggregate, pairwise_cosine from ..functions import scaled_l2, aggregate, pairwise_cosine
__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d'] __all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d',
'EncodingCosine']
class Encoding(Module): class Encoding(Module):
r""" r"""
...@@ -304,3 +305,43 @@ class UpsampleConv2d(Module): ...@@ -304,3 +305,43 @@ class UpsampleConv2d(Module):
out = F.conv2d(input, self.weight, self.bias, self.stride, out = F.conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups) self.padding, self.dilation, self.groups)
return F.pixel_shuffle(out, self.scale_factor) return F.pixel_shuffle(out, self.scale_factor)
# Experimental
class EncodingCosine(Module):
    """Encoding layer variant using cosine-similarity soft assignment
    (experimental)."""

    def __init__(self, D, K):
        super(EncodingCosine, self).__init__()
        # D: feature dimension; K: number of learnable codewords.
        self.D, self.K = D, K
        self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
        self.reset_params()

    def reset_params(self):
        """Initialize codewords uniformly in [-1/sqrt(K*D), 1/sqrt(K*D)]."""
        bound = 1. / ((self.K * self.D) ** (1 / 2))
        self.codewords.data.uniform_(-bound, bound)

    def forward(self, X):
        # Accepts BxDxN or BxDxHxW input; spatial dims are flattened to N.
        assert (X.size(1) == self.D)
        B = X.size(0)
        if X.dim() == 3:
            X = X.transpose(1, 2).contiguous()
        elif X.dim() == 4:
            X = X.view(B, self.D, -1).transpose(1, 2).contiguous()
        else:
            raise RuntimeError('Encoding Layer unknown input dims!')
        # Soft-assignment weights over the K codewords (softmax of cosine
        # similarities), then residual aggregation into the encoded vectors.
        A = F.softmax(pairwise_cosine(X, self.codewords), dim=2)
        return aggregate(A, X, self.codewords)

    def __repr__(self):
        return '{}(N x {}=>{}x{})'.format(
            self.__class__.__name__, self.D, self.K, self.D)
...@@ -2,128 +2,60 @@ import torch ...@@ -2,128 +2,60 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn as nn import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
import numpy as np
__all__ = ['SegmentationLosses', 'OhemCrossEntropy2d', 'OHEMSegmentationLosses']
class SegmentationLosses(nn.CrossEntropyLoss): __all__ = ['LabelSmoothing', 'NLLMultiLabelSmooth', 'SegmentationLosses']
"""2D Cross Entropy Loss with Auxilary Loss"""
def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
aux=False, aux_weight=0.4, weight=None,
ignore_index=-1):
super(SegmentationLosses, self).__init__(weight, None, ignore_index)
self.se_loss = se_loss
self.aux = aux
self.nclass = nclass
self.se_weight = se_weight
self.aux_weight = aux_weight
self.bceloss = nn.BCELoss(weight)
def forward(self, *inputs):
if not self.se_loss and not self.aux:
return super(SegmentationLosses, self).forward(*inputs)
elif not self.se_loss:
pred1, pred2, target = tuple(inputs)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
return loss1 + self.aux_weight * loss2
elif not self.aux:
pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
loss1 = super(SegmentationLosses, self).forward(pred, target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.se_weight * loss2
else:
pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
@staticmethod class LabelSmoothing(nn.Module):
def _get_batch_label_vector(target, nclass): """
# target is a 3D Variable BxHxW, output is 2D BxnClass NLL loss with label smoothing.
batch = target.size(0) """
tvect = Variable(torch.zeros(batch, nclass)) def __init__(self, smoothing=0.1):
for i in range(batch):
hist = torch.histc(target[i].cpu().data.float(),
bins=nclass, min=0,
max=nclass-1)
vect = hist>0
tvect[i] = vect
return tvect
# adapted from https://github.com/PkuRainBow/OCNet/blob/master/utils/loss.py
class OhemCrossEntropy2d(nn.Module):
def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True):
super(OhemCrossEntropy2d, self).__init__()
self.ignore_label = ignore_label
self.thresh = float(thresh)
self.min_kept = int(min_kept)
if use_weight:
print("w/ class balance")
weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
1.0865, 1.1529, 1.0507])
self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
else:
print("w/o class balance")
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
def forward(self, predict, target, weight=None):
""" """
Args: Constructor for the LabelSmoothing module.
predict:(n, c, h, w) :param smoothing: label smoothing factor
target:(n, h, w)
weight (Tensor, optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size "nclasses"
""" """
assert not target.requires_grad super(LabelSmoothing, self).__init__()
assert predict.dim() == 4 self.confidence = 1.0 - smoothing
assert target.dim() == 3 self.smoothing = smoothing
assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3))
n, c, h, w = predict.size() def forward(self, x, target):
input_label = target.data.cpu().numpy().ravel().astype(np.int32) logprobs = torch.nn.functional.log_softmax(x, dim=-1)
x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1))
input_prob = np.exp(x - x.max(axis=0).reshape((1, -1)))
input_prob /= input_prob.sum(axis=0).reshape((1, -1))
valid_flag = input_label != self.ignore_label nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
valid_inds = np.where(valid_flag)[0] nll_loss = nll_loss.squeeze(1)
label = input_label[valid_flag] smooth_loss = -logprobs.mean(dim=-1)
num_valid = valid_flag.sum() loss = self.confidence * nll_loss + self.smoothing * smooth_loss
if self.min_kept >= num_valid: return loss.mean()
print('Labels: {}'.format(num_valid))
elif num_valid > 0:
prob = input_prob[:,valid_flag]
pred = prob[label, np.arange(len(label), dtype=np.int32)]
threshold = self.thresh
if self.min_kept > 0:
index = pred.argsort()
threshold_index = index[ min(len(index), self.min_kept) - 1 ]
if pred[threshold_index] > self.thresh:
threshold = pred[threshold_index]
kept_flag = pred <= threshold
valid_inds = valid_inds[kept_flag]
label = input_label[valid_inds].copy() class NLLMultiLabelSmooth(nn.Module):
input_label.fill(self.ignore_label) def __init__(self, smoothing = 0.1):
input_label[valid_inds] = label super(NLLMultiLabelSmooth, self).__init__()
valid_flag_new = input_label != self.ignore_label self.confidence = 1.0 - smoothing
# print(np.sum(valid_flag_new)) self.smoothing = smoothing
target = Variable(torch.from_numpy(input_label.reshape(target.size())).long().cuda())
return self.criterion(predict, target) def forward(self, x, target):
if self.training:
x = x.float()
target = target.float()
logprobs = torch.nn.functional.log_softmax(x, dim = -1)
nll_loss = -logprobs * target
nll_loss = nll_loss.sum(-1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
else:
return torch.nn.functional.cross_entropy(x, target)
class OHEMSegmentationLosses(OhemCrossEntropy2d): class SegmentationLosses(nn.CrossEntropyLoss):
"""2D Cross Entropy Loss with Auxilary Loss""" """2D Cross Entropy Loss with Auxilary Loss"""
def __init__(self, se_loss=False, se_weight=0.2, nclass=-1, def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
aux=False, aux_weight=0.4, weight=None, aux=False, aux_weight=0.4, weight=None,
ignore_index=-1): ignore_index=-1):
super(OHEMSegmentationLosses, self).__init__(ignore_index) super(SegmentationLosses, self).__init__(weight, None, ignore_index)
self.se_loss = se_loss self.se_loss = se_loss
self.aux = aux self.aux = aux
self.nclass = nclass self.nclass = nclass
...@@ -133,23 +65,23 @@ class OHEMSegmentationLosses(OhemCrossEntropy2d): ...@@ -133,23 +65,23 @@ class OHEMSegmentationLosses(OhemCrossEntropy2d):
def forward(self, *inputs): def forward(self, *inputs):
if not self.se_loss and not self.aux: if not self.se_loss and not self.aux:
return super(OHEMSegmentationLosses, self).forward(*inputs) return super(SegmentationLosses, self).forward(*inputs)
elif not self.se_loss: elif not self.se_loss:
pred1, pred2, target = tuple(inputs) pred1, pred2, target = tuple(inputs)
loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target) loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target) loss2 = super(SegmentationLosses, self).forward(pred2, target)
return loss1 + self.aux_weight * loss2 return loss1 + self.aux_weight * loss2
elif not self.aux: elif not self.aux:
pred, se_pred, target = tuple(inputs) pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred) se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
loss1 = super(OHEMSegmentationLosses, self).forward(pred, target) loss1 = super(SegmentationLosses, self).forward(pred, target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target) loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.se_weight * loss2 return loss1 + self.se_weight * loss2
else: else:
pred1, se_pred, pred2, target = tuple(inputs) pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1) se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target) loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target) loss2 = super(SegmentationLosses, self).forward(pred2, target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target) loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.aux_weight * loss2 + self.se_weight * loss3 return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -9,14 +9,10 @@ ...@@ -9,14 +9,10 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Util Tools""" """Encoding Util Tools"""
from .lr_scheduler import LR_Scheduler from .lr_scheduler import *
from .metrics import SegmentationMetric, batch_intersection_union, batch_pix_accuracy from .metrics import *
from .pallete import get_mask_pallete from .pallete import get_mask_pallete
from .train_helper import * from .train_helper import *
from .presets import load_image from .presets import load_image
from .files import * from .files import *
from .misc import * from .misc import *
__all__ = ['LR_Scheduler', 'batch_pix_accuracy', 'batch_intersection_union',
'save_checkpoint', 'download', 'mkdir', 'check_sha1', 'load_image',
'get_mask_pallete', 'get_selabel_vector', 'EMA']
...@@ -10,7 +10,10 @@ __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1'] ...@@ -10,7 +10,10 @@ __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'): def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
"""Saves checkpoint to disk""" """Saves checkpoint to disk"""
directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname) if hasattr(args, 'backbone'):
directory = "runs/%s/%s/%s/%s/"%(args.dataset, args.model, args.backbone, args.checkname)
else:
directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
if not os.path.exists(directory): if not os.path.exists(directory):
os.makedirs(directory) os.makedirs(directory)
filename = directory + filename filename = directory + filename
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment