Commit b8428a7f authored by peng xu

Merge branch 'main' into beam_search

parents e5034150 3f4e71df
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# Copyright (c) go-hiroaki & Chokurei
# email: guangmingwu2010@gmail.com
#        guozhilingty@gmail.com
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
eps = 1e-6
def _binarize(y_data, threshold):
"""
args:
y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
threshold : [float] [0.0, 1.0]
return 4-d binarized y_data
"""
y_data[y_data < threshold] = 0.0
y_data[y_data >= threshold] = 1.0
return y_data
def _argmax(y_data, dim):
"""
args:
y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
dim : int
return 3-d [int] y_data
"""
return torch.argmax(y_data, dim).int()
def _get_tp(y_pred, y_true):
"""
args:
y_true : [int] 3-d in [batch_size, img_rows, img_cols]
y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
return [float] true_positive
"""
return torch.sum(y_true * y_pred).float()
def _get_fp(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_positive
"""
return torch.sum((1 - y_true) * y_pred).float()
def _get_tn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] true_negative
"""
return torch.sum((1 - y_true) * (1 - y_pred)).float()
def _get_fn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_negative
"""
return torch.sum(y_true * (1 - y_pred)).float()
def _get_weights(y_true, nb_ch):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
nb_ch : int
return [float] weights
"""
batch_size, img_rows, img_cols = y_true.shape
pixels = batch_size * img_rows * img_cols
weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)]
return weights
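# Worked example: for a 3-class mask where class 0 covers half of the pixels
# and classes 1 and 2 a quarter each, _get_weights(y_true, 3) -> [0.5, 0.25, 0.25].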
class CFMatrix(object):
def __init__(self, des=None):
self.des = des
def __repr__(self):
return "ConfusionMatrix"
    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d tensor in [batch_size, img_rows, img_cols]
            y_pred : 3-d tensor in [batch_size, img_rows, img_cols]
            ignore_index : int, label to ignore; also assumed to equal the
                number of valid classes (labels run from 0 to ignore_index-1)
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        batch_size, img_rows, img_cols = y_pred.shape
        chs = ignore_index  # valid classes are 0..ignore_index-1
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
nb_tn = _get_tn(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
performs = None
else:
performs = torch.zeros(chs, 4).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
y_true_ch[y_true == ch] = 1
y_false_ch[torch.logical_and((y_true != ch), (y_true != ignore_index))] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.stack([nb_tp, nb_fp, nb_tn, nb_fn])
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
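# A minimal usage sketch for CFMatrix (hypothetical, for illustration only):
# labels live in {0, 1, 2} and ignore_index=3 doubles as the class count, as
# __call__ above assumes.
def _demo_cfmatrix():
    y_true = torch.randint(0, 3, (2, 16, 16))
    y_pred = torch.randint(0, 3, (2, 16, 16))
    mperforms, performs = CFMatrix()(y_pred, y_true, ignore_index=3)
    return mperforms, performs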
class OAAcc(object):
def __init__(self, des="Overall Accuracy"):
self.des = des
def __repr__(self):
return "OAcc"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return (tp+tn)/total
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
nb_tp_tn = torch.sum(y_true == y_pred).float()
mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
performs = None
return mperforms, performs
class Precision(object):
def __init__(self, des="Precision"):
self.des = des
def __repr__(self):
return "Prec"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return tp/(tp+fp)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class Recall(object):
def __init__(self, des="Recall"):
self.des = des
def __repr__(self):
return "Reca"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return tp/(tp+fn)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class F1Score(object):
def __init__(self, des="F1Score"):
self.des = des
def __repr__(self):
return "F1Sc"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return 2*precision*recall/(precision+recall)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + eps)
            _recall = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = 2 * _precision * _recall / (_precision + _recall + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = _get_fp(y_pred_ch, y_true_ch)
nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + eps)
                _recall = nb_tp / (nb_tp + nb_fn + eps)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class Kappa(object):
def __init__(self, des="Kappa"):
self.des = des
def __repr__(self):
return "Kapp"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return (Po-Pe)/(1-Pe)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
nb_tp = _get_tp(y_pred, y_true)
nb_fp = _get_fp(y_pred, y_true)
nb_tn = _get_tn(y_pred, y_true)
nb_fn = _get_fn(y_pred, y_true)
nb_total = nb_tp + nb_fp + nb_tn + nb_fn
Po = (nb_tp + nb_tn) / nb_total
Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
(nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
            mperforms = (Po - Pe) / (1 - Pe + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
nb_tp = _get_tp(y_pred_ch, y_true_ch)
nb_fp = _get_fp(y_pred_ch, y_true_ch)
nb_tn = _get_tn(y_pred_ch, y_true_ch)
nb_fn = _get_fn(y_pred_ch, y_true_ch)
nb_total = nb_tp + nb_fp + nb_tn + nb_fn
Po = (nb_tp + nb_tn) / nb_total
Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
+ (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class Jaccard(object):
def __init__(self, des="Jaccard"):
self.des = des
def __repr__(self):
return "Jacc"
def __call__(self, y_pred, y_true, threshold=0.5):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return intersection / (sum-intersection)
"""
batch_size, chs, img_rows, img_cols = y_true.shape
device = y_true.device
if chs == 1:
y_pred = _binarize(y_pred, threshold)
y_true = _binarize(y_true, threshold)
_intersec = torch.sum(y_true * y_pred).float()
_sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + eps)
performs = None
else:
y_pred = _argmax(y_pred, 1)
y_true = _argmax(y_true, 1)
performs = torch.zeros(chs, 1).to(device)
weights = _get_weights(y_true, chs)
for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
y_true_ch[y_true == ch] = 1
y_pred_ch[y_pred == ch] = 1
_intersec = torch.sum(y_true_ch * y_pred_ch).float()
_sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + eps)
mperforms = sum([i*j for (i, j) in zip(performs, weights)])
return mperforms, performs
class MSE(object):
def __init__(self, des="Mean Square Error"):
self.des = des
def __repr__(self):
return "MSE"
def __call__(self, y_pred, y_true, dim=1, threshold=None):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
threshold : [0.0, 1.0]
return mean_squared_error, smaller the better
"""
if threshold:
y_pred = _binarize(y_pred, threshold)
return torch.mean((y_pred - y_true) ** 2)
class PSNR(object):
def __init__(self, des="Peak Signal to Noise Ratio"):
self.des = des
def __repr__(self):
return "PSNR"
def __call__(self, y_pred, y_true, dim=1, threshold=None):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
threshold : [0.0, 1.0]
return PSNR, larger the better
"""
if threshold:
y_pred = _binarize(y_pred, threshold)
mse = torch.mean((y_pred - y_true) ** 2)
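        # PSNR = 10 * log10(MAX^2 / MSE); MAX is assumed to be 1 here,
        # i.e. inputs normalized to [0, 1].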
return 10 * torch.log10(1 / mse)
class SSIM(object):
'''
modified from https://github.com/jorge-pessoa/pytorch-msssim
'''
def __init__(self, des="structural similarity index"):
self.des = des
def __repr__(self):
return "SSIM"
def gaussian(self, w_size, sigma):
gauss = torch.Tensor([math.exp(-(x - w_size//2)**2/float(2*sigma**2)) for x in range(w_size)])
return gauss/gauss.sum()
def create_window(self, w_size, channel=1):
_1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
return window
def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
w_size : int, default 11
size_average : boolean, default True
full : boolean, default False
return ssim, larger the better
"""
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if torch.max(y_pred) > 128:
max_val = 255
else:
max_val = 1
if torch.min(y_pred) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
padd = 0
(_, channel, height, width) = y_pred.size()
window = self.create_window(w_size, channel=channel).to(y_pred.device)
mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2
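        # Stability constants from the SSIM paper (Wang et al., 2004):
        # C1 = (K1*L)^2 and C2 = (K2*L)^2 with K1 = 0.01, K2 = 0.03.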
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
ret = ssim_map.mean()
else:
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
class AE(object):
"""
Modified from matlab : colorangle.m, MATLAB V2019b
angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
angle = 180 / pi * angle;
"""
def __init__(self, des='average Angular Error'):
self.des = des
def __repr__(self):
return "AE"
def __call__(self, y_pred, y_true):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
return average AE, smaller the better
"""
dotP = torch.sum(y_pred * y_true, dim=1)
Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps))
return ae.mean(1).mean(1)
if __name__ == "__main__":
for ch in [3, 1]:
batch_size, img_row, img_col = 1, 224, 224
y_true = torch.rand(batch_size, ch, img_row, img_col)
noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
y_pred = y_true + noise
for cuda in [False, True]:
if cuda:
y_pred = y_pred.cuda()
y_true = y_true.cuda()
print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
########### similarity metrics
metric = MSE()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
metric = PSNR()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
metric = SSIM()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
            # LPIPS is not defined in this module; it would need an external
            # implementation (e.g. the `lpips` package), so it is skipped here.
            # metric = LPIPS(cuda)
            # acc = metric(y_pred, y_true).item()
            # print("{} ==> {}".format(repr(metric), acc))
metric = AE()
acc = metric(y_pred, y_true).item()
print("{} ==> {}".format(repr(metric), acc))
########### accuracy metrics
metric = OAAcc()
maccu, accu = metric(y_pred, y_true)
print('mAccu:', maccu, 'Accu', accu)
metric = Precision()
mprec, prec = metric(y_pred, y_true)
print('mPrec:', mprec, 'Prec', prec)
metric = Recall()
mreca, reca = metric(y_pred, y_true)
print('mReca:', mreca, 'Reca', reca)
metric = F1Score()
mf1sc, f1sc = metric(y_pred, y_true)
print('mF1sc:', mf1sc, 'F1sc', f1sc)
metric = Kappa()
mkapp, kapp = metric(y_pred, y_true)
print('mKapp:', mkapp, 'Kapp', kapp)
metric = Jaccard()
mjacc, jacc = metric(y_pred, y_true)
print('mJacc:', mjacc, 'Jacc', jacc)
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize
class SetrSegmentationHead(MegatronModule):
def __init__(self, hidden_size, num_classes):
super(SetrSegmentationHead, self).__init__()
args = get_args()
self.hidden_size = hidden_size
self.num_classes = num_classes
self.img_h = args.img_h
self.img_w = args.img_w
self.patch_dim = args.patch_dim
self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
1, 1, bias=False)
self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)
def to_2D(self, x):
n, hw, c = x.shape
h = self.img_h // self.patch_dim
w = self.img_w // self.patch_dim
assert(hw == h * w)
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def forward(self, hidden_states):
# [b c h w]
hidden_states = self.layernorm(hidden_states)
hidden_states = self.to_2D(hidden_states)
hidden_states = self.conv_0(hidden_states)
hidden_states = self.norm_0(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.conv_1(hidden_states)
# [b c h w]
        result = F.interpolate(hidden_states,
                               size=(self.img_h, self.img_w),
                               mode='bilinear',
                               align_corners=False)
return result
class MLP(torch.nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = torch.nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class SegformerSegmentationHead(MegatronModule):
def __init__(self, feature_strides, in_channels,
embedding_dim, dropout_ratio):
super(SegformerSegmentationHead, self).__init__()
assert len(feature_strides) == len(in_channels)
assert min(feature_strides) == feature_strides[0]
args = get_args()
self.feature_strides = feature_strides
self.in_channels = in_channels
self.embedding_dim = embedding_dim
self.num_classes = args.num_classes
self.dropout_ratio = dropout_ratio
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
self.in_channels
self.linear_c4 = MLP(input_dim=c4_in_channels,
embed_dim=self.embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels,
embed_dim=self.embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels,
embed_dim=self.embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels,
embed_dim=self.embedding_dim)
self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4,
self.embedding_dim, 1, 1)
self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
self.num_classes,
kernel_size=1)
def forward(self, inputs):
c1, c2, c3, c4 = inputs
############## MLP decoder on C1-C4 ###########
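        # Shape sketch (illustrative, assuming a 512x512 input with feature
        # strides [4, 8, 16, 32] and in_channels [64, 128, 320, 512]):
        # c1 [n, 64, 128, 128], c2 [n, 128, 64, 64], c3 [n, 320, 32, 32],
        # c4 [n, 512, 16, 16]. Each scale is projected to embedding_dim and
        # resized to c1's spatial size before fusion, so the predicted logits
        # come out at stride 4.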
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
x = self.norm(_c)
x = F.relu(x, inplace=True)
x = self.dropout(x)
x = self.linear_pred(x)
return x
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead
class SetrSegmentationModel(MegatronModule):
def __init__(self,
num_classes,
pre_process=True,
post_process=True):
super(SetrSegmentationModel, self).__init__()
args = get_args()
        assert pre_process and post_process
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = VitBackbone(
pre_process=pre_process,
post_process=post_process,
class_token=False,
post_layer_norm=False,
drop_path_rate=0.1
)
self.head = SetrSegmentationHead(
self.hidden_size,
self.num_classes
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
# [b hw c]
hidden_states = self.backbone(input)
result_final = self.head(hidden_states)
return result_final
class SegformerSegmentationModel(MegatronModule):
def __init__(self,
num_classes,
pre_process=True,
post_process=True):
super(SegformerSegmentationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
self.backbone = mit_b5()
self.head = SegformerSegmentationHead(
feature_strides=[4, 8, 16, 32],
in_channels=[64, 128, 320, 512],
embedding_dim=768,
dropout_ratio=0.1
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
# [b hw c]
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
# Copyright (c) 2020 The MMSegmentation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_args
from PIL import Image, ImageOps, ImageEnhance
import torchvision.transforms as torch_tr
def _is_pil_image(img):
return isinstance(img, Image.Image)
class PhotoMetricDistortion(object):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def convert(self, img, alpha=1, beta=0):
"""Multiple with alpha and add beat with clip."""
img = img.astype(np.float32) * alpha + beta
img = np.clip(img, 0, 255)
return img.astype(np.uint8)
def brightness(self, img):
"""Brightness distortion."""
if random.randint(0, 1):
return self.convert(
img,
beta=random.uniform(-self.brightness_delta,
self.brightness_delta))
return img
def contrast(self, img):
"""Contrast distortion."""
if random.randint(0, 1):
return self.convert(
img,
alpha=random.uniform(self.contrast_lower, self.contrast_upper))
return img
def saturation(self, img):
"""Saturation distortion."""
if random.randint(0, 1):
img = mmcv.bgr2hsv(img)
img[:, :, 1] = self.convert(
img[:, :, 1],
alpha=random.uniform(self.saturation_lower,
self.saturation_upper))
img = mmcv.hsv2bgr(img)
return img
def hue(self, img):
"""Hue distortion."""
if random.randint(0, 1):
img = mmcv.bgr2hsv(img)
img[:, :,
0] = (img[:, :, 0].astype(int) +
random.randint(-self.hue_delta, self.hue_delta)) % 180
img = mmcv.hsv2bgr(img)
return img
def __call__(self, img):
"""Call function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
img = np.array(img)
# random brightness
img = self.brightness(img)
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = random.randint(0, 1)
if mode == 1:
img = self.contrast(img)
# random saturation
img = self.saturation(img)
# random hue
img = self.hue(img)
# random contrast
if mode == 0:
img = self.contrast(img)
img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
return img
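# A hedged usage sketch (illustrative only): run the distortion pipeline on a
# random RGB image; __call__ takes and returns a PIL image.
def _demo_photometric_distortion():
    rand_rgb = Image.fromarray(
        np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), 'RGB')
    return PhotoMetricDistortion(brightness_delta=32,
                                 contrast_range=(0.5, 1.5),
                                 saturation_range=(0.5, 1.5),
                                 hue_delta=18)(rand_rgb)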
class RandomCrop(object):
"""
Take a random crop from the image.
    First the image or crop size may need to be adjusted if the incoming image
    is too small.
    If the image is smaller than the crop, then:
        the image is padded up to the size of the crop,
        unless 'nopad', in which case the crop size is shrunk to fit the image.
    A random crop is taken such that the crop fits within the image.
    If cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there is always
    translation randomness of at least that value around the image.
if image < crop_size:
# slide crop within image, random offset
else:
# slide image within crop
"""
def __init__(self, crop_size):
args = get_args()
self.size = crop_size
self.cat_max_ratio = 0.75
self.ignore_index = args.ignore_index
self.pad_color = (0, 0, 0)
def get_crop_bbox(self, img):
"""Randomly get a crop bounding box."""
img_w, img_h = img.size
target_h, target_w = self.size #[H W]
margin_h = max(img_h - target_h, 0)
margin_w = max(img_w - target_w, 0)
offset_h = random.randint(0, margin_h)
offset_w = random.randint(0, margin_w)
crop_y1, crop_y2 = offset_h, offset_h + target_h
crop_x1, crop_x2 = offset_w, offset_w + target_w
return crop_y1, crop_y2, crop_x1, crop_x2
def crop(self, img, crop_bbox):
"""Crop from ``img``"""
crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
return img
@staticmethod
def crop_in_image(target_w, target_h, w, h, img, mask):
if w == target_w:
x1 = 0
else:
x1 = random.randint(0, w - target_w)
if h == target_h:
y1 = 0
else:
y1 = random.randint(0, h - target_h)
return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
mask.crop((x1, y1, x1 + target_w, y1 + target_h))]
def __call__(self, img, mask):
w, h = img.size
target_h, target_w = self.size # ASSUME H, W
if w == target_w and h == target_h:
return img, mask
# Pad image if image < crop
if target_h > h:
pad_h = (target_h - h) // 2 + 1
else:
pad_h = 0
if target_w > w:
pad_w = (target_w - w) // 2 + 1
else:
pad_w = 0
border = (pad_w, pad_h, pad_w, pad_h)
if pad_h or pad_w:
img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
w, h = img.size
crop_bbox = self.get_crop_bbox(img)
if self.cat_max_ratio < 1.:
# Repeat 10 times
for _ in range(10):
seg_temp = self.crop(mask, crop_bbox)
labels, cnt = np.unique(seg_temp, return_counts=True)
cnt = cnt[labels != self.ignore_index]
if len(cnt) > 1 and np.max(cnt) / np.sum(
cnt) < self.cat_max_ratio:
break
crop_bbox = self.get_crop_bbox(img)
# crop the image
img = self.crop(img, crop_bbox)
# crop semantic seg
mask = self.crop(mask, crop_bbox)
assert(img.size[0] == self.size[1] and img.size[1] == self.size[0])
return img, mask
class RandomSizeAndCrop(object):
def __init__(self,
crop_size,
scale_min=0.5,
scale_max=2.0):
self.crop = RandomCrop(crop_size)
self.scale_min = scale_min
self.scale_max = scale_max
def __call__(self, img, mask):
scale_amt = random.uniform(self.scale_min, self.scale_max)
w, h = [int(i * scale_amt) for i in img.size]
resized_img = img.resize((w, h), Image.BICUBIC)
resized_mask = mask.resize((w, h), Image.NEAREST)
img, mask = self.crop(resized_img, resized_mask)
return img, mask
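# A hedged usage sketch (assumes megatron arguments are initialized, since
# RandomCrop reads args.ignore_index): rescale an image/mask pair by a random
# factor in [0.5, 2.0], then take a 512x512 crop.
def _demo_random_size_and_crop(img, mask):
    aug = RandomSizeAndCrop(crop_size=(512, 512), scale_min=0.5, scale_max=2.0)
    return aug(img, mask)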
class RandomHorizontallyFlip(object):
def __call__(self, img, mask):
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(
Image.FLIP_LEFT_RIGHT)
return img, mask
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
PIL Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL Image: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args:
img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
PIL Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of wrap-around across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
Args:
brightness (float): How much to jitter brightness. brightness_factor
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
contrast (float): How much to jitter contrast. contrast_factor
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
saturation (float): How much to jitter saturation. saturation_factor
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
[-hue, hue]. Should be >=0 and <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
@staticmethod
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
Arguments are same as that of __init__.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
"""
transforms = []
if brightness > 0:
brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
transforms.append(
torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
if contrast > 0:
contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
transforms.append(
torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
if saturation > 0:
saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
transforms.append(
torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
if hue > 0:
hue_factor = np.random.uniform(-hue, hue)
transforms.append(
torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
np.random.shuffle(transforms)
transform = torch_tr.Compose(transforms)
return transform
def __call__(self, img):
"""
Args:
img (PIL Image): Input image.
Returns:
PIL Image: Color jittered image.
"""
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)
return transform(img)
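# A hedged usage sketch (illustrative): jitter brightness, contrast and
# saturation by up to 30% and hue by up to 0.1, within the parameter ranges
# documented above.
def _demo_color_jitter(img):
    jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
    return jitter(img)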
import math
import torch
import numpy as np
from megatron import get_args
def slidingcrops(img, mask):
# img: [b c h w]
# mask: [b h w]
args = get_args()
assert args.img_h == args.img_w
crop_size = args.img_h
stride = args.seg_stride
ignore_index = args.ignore_index
n, c, h, w = img.shape
assert h >= crop_size
assert w >= crop_size
long_size = max(h, w)
img_slices, mask_slices, slices_info = [], [], []
if long_size > crop_size:
assert stride <= crop_size
h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
for yy in range(h_step_num):
for xx in range(w_step_num):
sy, sx = yy * stride, xx * stride
ey, ex = sy + crop_size, sx + crop_size
img_sub = img[:, :, sy: ey, sx: ex]
mask_sub = mask[:, sy: ey, sx: ex]
# padding
sub_h, sub_w = img_sub.shape[2:]
pad_h = max(crop_size - sub_h, 0)
pad_w = max(crop_size - sub_w, 0)
img_sub = torch.nn.functional.pad(img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
mask_sub = torch.nn.functional.pad(mask_sub, pad=(0, pad_w, 0, pad_h))
img_slices.append(img_sub)
mask_slices.append(mask_sub)
slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
else:
return img, mask, [[0, h, 0, w, h, w]], (h, w)
def slidingjoins(preds, probs, labels, slices_info, img_size):
args = get_args()
num_slices = len(slices_info)
if num_slices == 1:
return preds, labels
h, w = img_size
split_size = args.micro_batch_size
preds_split = torch.split(preds, split_size)
probs_split = torch.split(probs, split_size)
labels_split = torch.split(labels, split_size)
assert(len(preds_split) == num_slices)
total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device='cuda')
total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
for i in range(num_slices):
sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
assert sy + sub_h <= h
assert sx + sub_w <= w
curr_max_probs = total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w]
curr_preds = total_preds[:, sy:sy + sub_h, sx:sx + sub_w]
local_max_probs = probs_split[i][:, :sub_h, : sub_w]
local_preds = preds_split[i][:, :sub_h, :sub_w]
result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
result_preds = torch.where(curr_max_probs >= local_max_probs, curr_preds, local_preds)
total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w] = result_max_probs
total_preds[:, sy:sy + sub_h, sx:sx + sub_w] = result_preds
total_labels[:, sy:sy + sub_h, sx:sx + sub_w] = labels_split[i][0, :sub_h, :sub_w]
return total_preds, total_labels
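# A hedged sketch of the intended round trip (illustrative; assumes megatron
# args are initialized and `model` maps [b c h w] images to per-pixel class
# scores): crop into windows, predict each window, then stitch the windows
# back, keeping the most confident prediction wherever they overlap.
def _sliding_inference(model, img, mask):
    crops, mask_crops, slices_info, img_size = slidingcrops(img, mask)
    probs = torch.softmax(model(crops), dim=1)   # [n, num_classes, h, w]
    max_probs, preds = torch.max(probs, dim=1)   # [n, h, w]
    return slidingjoins(preds.int(), max_probs, mask_crops,
                        slices_info, img_size)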
import os
import sys
import json
import argparse
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron.data import indexed_dataset
def main(args):
prefixes = set()
for basename in os.listdir(args.input):
prefix, ext = os.path.splitext(basename)
if prefix in prefixes:
continue
if not os.path.isfile(os.path.join(args.input, basename)):
continue
ext_pair = '.bin' if ext == '.idx' else '.idx'
assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'
prefixes.add(prefix)
builder = None
for prefix in sorted(prefixes):
if builder is None:
dataset = indexed_dataset.make_dataset(os.path.join(args.input, prefix), 'infer')
if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
builder = indexed_dataset.MMapIndexedDatasetBuilder(args.output_prefix + '.bin', dtype=dataset._index.dtype)
else:
builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')
del dataset
builder.merge_file_(os.path.join(args.input, prefix))
builder.finalize(args.output_prefix + '.idx')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to directory containing all document files to merge')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
args = parser.parse_args()
assert os.path.isdir(args.input), \
f'ERROR: {args.input} is not a directory or does not exist'
assert os.path.isdir(os.path.dirname(args.output_prefix)), \
f'ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist'
main(args)
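# Example invocation (the script name and paths are illustrative):
#   python merge_datasets.py --input /path/to/shards --output-prefix /path/to/merged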
@@ -122,8 +122,10 @@ def get_args():
                        choices=['lazy', 'cached', 'mmap'])

     group = parser.add_argument_group(title='runtime')
-    group.add_argument('--workers', type=int, default=1,
+    group.add_argument('--workers', type=int, required=True,
                        help='Number of worker processes to launch')
+    group.add_argument('--chunk-size', type=int, required=True,
+                       help='Chunk size assigned to each worker process')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')

     args = parser.parse_args()
@@ -154,7 +156,7 @@ def main():
     encoder = Encoder(args)
     tokenizer = build_tokenizer(args)

     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
+    encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
     #encoded_docs = map(encoder.encode, fin)
     level = "document"
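The hunk above threads the new --chunk-size flag through to Pool.imap, whose
third argument controls how many work items are dispatched to each worker at
a time. A minimal standalone sketch of the same pattern (the function and the
input file below are illustrative, not from this repository):

import multiprocessing

def encode(line):
    # stand-in for Encoder.encode
    return line.strip().upper()

if __name__ == '__main__':
    with multiprocessing.Pool(processes=4) as pool:
        with open('docs.txt') as fin:
            # a larger chunk size amortizes IPC overhead across items
            for doc in pool.imap(encode, fin, 32):
                pass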