Commit ce0e5303 authored by bailuo: init
import os
import sys
import cv2
import numpy as np
sys.path.insert(0, './utils')
from evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss
import argparse
from tqdm import tqdm
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pred-dir', type=str, default='path/to/outputs/ppm100', help="pred alpha dir")
parser.add_argument('--label-dir', type=str, default='path/to/PPM-100/matte', help="GT alpha dir")
parser.add_argument('--detailmap-dir', type=str, default='path/to/PPM-100/matte', help="trimap dir")
args = parser.parse_args()
mse_loss = []
sad_loss = []
mad_loss = []
### loss_unknown only considers the unknown region, i.e. trimap == 128, as trimap-based methods do
#mse_loss_unknown = []
#sad_loss_unknown = []
for img in tqdm(os.listdir(args.label_dir)):
print(img)
#pred = cv2.imread(os.path.join(args.pred_dir, img.replace('.png', '.jpg')), 0).astype(np.float32)
pred = cv2.imread(os.path.join(args.pred_dir, img.replace('.jpg', '.png')), 0).astype(np.float32)
label = cv2.imread(os.path.join(args.label_dir, img), 0).astype(np.float32)
detailmap = cv2.imread(os.path.join(args.detailmap_dir, img), 0).astype(np.float32)
#detailmap[detailmap > 0] = 128
#mse_loss_unknown_ = compute_mse_loss(pred, label, detailmap)
#sad_loss_unknown_ = compute_sad_loss(pred, label, detailmap)[0]
detailmap[...] = 128
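# overwrite the detail map so that every pixel counts as "unknown" (value 128):
# the metrics below are therefore computed over the whole image rather than only
# the trimap unknown region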
mse_loss_ = compute_mse_loss(pred, label, detailmap)
sad_loss_ = compute_sad_loss(pred, label, detailmap)[0]
mad_loss_ = compute_mad_loss(pred, label, detailmap)
print('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_, ' MAD:', mad_loss_)
#print('Detail Region: MSE:', mse_loss_unknown_, ' SAD:', sad_loss_unknown_)
#mse_loss_unknown.append(mse_loss_unknown_)
#sad_loss_unknown.append(sad_loss_unknown_)
mse_loss.append(mse_loss_)
sad_loss.append(sad_loss_)
mad_loss.append(mad_loss_)
print('Average:')
print('Whole Image: MSE:', np.array(mse_loss).mean(), ' SAD:', np.array(sad_loss).mean(), ' MAD:', np.array(mad_loss).mean())
#print('Detail Region: MSE:', np.array(mse_loss_unknown).mean(), ' SAD:', np.array(sad_loss_unknown).mean())
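# Example (illustrative) of running this evaluation script; the script file name is
# an assumption and the paths are placeholders.
#   python evaluation_ppm100.py --pred-dir path/to/outputs/ppm100 \
#       --label-dir path/to/PPM-100/matte --detailmap-dir path/to/PPM-100/matte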
import os
import cv2
import numpy as np
import sys
sys.path.insert(0, './utils')
from evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss
import argparse
from tqdm import tqdm
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pred-dir', type=str, default='path/to/outputs/rw100', help="pred alpha dir")
parser.add_argument('--label-dir', type=str, default='path/to/RefMatte_RW_100/mask/', help="GT alpha dir")
parser.add_argument('--detailmap-dir', type=str, default='path/to/RefMatte_RW_100/mask/', help="trimap dir")
args = parser.parse_args()
mse_loss = []
sad_loss = []
mad_loss = []
### loss_unknown only considers the unknown region, i.e. trimap == 128, as trimap-based methods do
#mse_loss_unknown = []
#sad_loss_unknown = []
for img in tqdm(os.listdir(args.label_dir)):
print(img)
#pred = cv2.imread(os.path.join(args.pred_dir, img.replace('.png', '.jpg')), 0).astype(np.float32)
pred = cv2.imread(os.path.join(args.pred_dir, img), 0).astype(np.float32)
label = cv2.imread(os.path.join(args.label_dir, img), 0).astype(np.float32)
detailmap = cv2.imread(os.path.join(args.detailmap_dir, img), 0).astype(np.float32)
#detailmap[detailmap > 0] = 128
#mse_loss_unknown_ = compute_mse_loss(pred, label, detailmap)
#sad_loss_unknown_ = compute_sad_loss(pred, label, detailmap)[0]
detailmap[...] = 128
mse_loss_ = compute_mse_loss(pred, label, detailmap)
sad_loss_ = compute_sad_loss(pred, label, detailmap)[0]
mad_loss_ = compute_mad_loss(pred, label, detailmap)
print('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_, ' MAD:', mad_loss_)
#print('Detail Region: MSE:', mse_loss_unknown_, ' SAD:', sad_loss_unknown_)
#mse_loss_unknown.append(mse_loss_unknown_)
#sad_loss_unknown.append(sad_loss_unknown_)
mse_loss.append(mse_loss_)
sad_loss.append(sad_loss_)
mad_loss.append(mad_loss_)
print('Average:')
print('Whole Image: MSE:', np.array(mse_loss).mean(), ' SAD:', np.array(sad_loss).mean(), ' MAD:', np.array(mad_loss).mean())
#print('Detail Region: MSE:', np.array(mse_loss_unknown).mean(), ' SAD:', np.array(sad_loss_unknown).mean())
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import math
import time
import skimage.measure
import torch.nn.functional as F
from PIL import Image
from scipy import ndimage
from scipy.ndimage.morphology import distance_transform_edt
from multiprocessing import Pool
def findMaxConnectedRegion(x):
assert len(x.shape) == 2
cc, num = skimage.measure.label(x, connectivity=1, return_num=True)
omega = np.zeros_like(x)
if num > 0:
# find the largest connected region
max_id = np.argmax(np.bincount(cc.flatten())[1:]) + 1
omega[cc == max_id] = 1
return omega
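# build a pair of derivative-of-Gaussian kernels (x and y directions); they are used
# by the Gradient metric to compute image gradients of the predicted and GT alpha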
def genGaussKernel(sigma, q=2):
pi = math.pi
eps = 1e-2
def gauss(x, sigma):
return np.exp(-np.power(x,2)/(2*np.power(sigma,2))) / (sigma*np.sqrt(2*pi))
def dgauss(x, sigma):
return -x * gauss(x,sigma) / np.power(sigma, 2)
hsize = int(np.ceil(sigma*np.sqrt(-2*np.log(np.sqrt(2*pi)*sigma*eps))))
size = 2 * hsize + 1
hx = np.zeros([size, size], dtype=np.float32)
for i in range(size):
for j in range(size):
u, v = i-hsize, j-hsize
hx[i,j] = gauss(u,sigma) * dgauss(v,sigma)
hx = hx / np.sqrt(np.sum(np.power(np.abs(hx), 2)))
hy = hx.transpose(1, 0)
return hx, hy, size
def calcOpticalFlow(frames):
prev, curr = frames
flow = cv2.calcOpticalFlowFarneback(prev.astype(np.uint8), curr.astype(np.uint8), None,
0.5, 5, 10, 2, 7, 1.5,
cv2.OPTFLOW_FARNEBACK_GAUSSIAN)
return flow
class ImageFilter(nn.Module):
def __init__(self, chn, kernel_size, weight, device):
super(ImageFilter, self).__init__()
self.kernel_size = kernel_size
assert kernel_size == weight.size(-1)
self.filter = nn.Conv2d(chn, chn, kernel_size, padding=0, bias=False)
self.filter.weight = nn.Parameter(weight)
self.device = device
def pad(self, x):
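# replicate-style padding implemented by hand: corners and edges are filled by
# repeating the nearest border pixel, so the subsequent valid convolution keeps
# the original spatial size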
assert len(x.shape) == 3
x = x.unsqueeze(-1).permute((0,3,1,2))
b, c, h, w = x.shape
pad = self.kernel_size // 2
y = torch.zeros([b, c, h+pad*2, w+pad*2]).to(self.device)
y[:,:,0:pad,0:pad] = x[:,:,0:1,0:1].repeat(1,1,pad,pad)
y[:,:,0:pad,w+pad:] = x[:,:,0:1,-1:].repeat(1,1,pad,pad)
y[:,:,h+pad:,0:pad] = x[:,:,-1:,0:1].repeat(1,1,pad,pad)
y[:,:,h+pad:,w+pad:] = x[:,:,-1:,-1:].repeat(1,1,pad,pad)
y[:,:,0:pad,pad:w+pad] = x[:,:,0:1,:].repeat(1,1,pad,1)
y[:,:,pad:h+pad,0:pad] = x[:,:,:,0:1].repeat(1,1,1,pad)
y[:,:,h+pad:,pad:w+pad] = x[:,:,-1:,:].repeat(1,1,pad,1)
y[:,:,pad:h+pad,w+pad:] = x[:,:,:,-1:].repeat(1,1,1,pad)
y[:,:,pad:h+pad, pad:w+pad] = x
return y
def forward(self, x):
y = self.filter(self.pad(x))
return y
class BatchMetric(object):
def __init__(self, device, grad_sigma=1.4, grad_q=2,
conn_step=0.1, conn_thresh=0.5, conn_theta=0.15, conn_p=1):
# parameters for connectivity
self.conn_step = conn_step
self.conn_thresh = conn_thresh
self.conn_theta = conn_theta
self.conn_p = conn_p
self.device = device
hx, hy, size = genGaussKernel(grad_sigma, grad_q)
self.hx = hx
self.hy = hy
self.kernel_size = size
kx = self.hx[::-1, ::-1].copy()
ky = self.hy[::-1, ::-1].copy()
kernel_x = torch.from_numpy(kx).unsqueeze(0).unsqueeze(0)
kernel_y = torch.from_numpy(ky).unsqueeze(0).unsqueeze(0)
self.fx = ImageFilter(1, self.kernel_size, kernel_x, self.device).cuda(self.device)
self.fy = ImageFilter(1, self.kernel_size, kernel_y, self.device).cuda(self.device)
def run(self, input, target, mask=None, calc_mad=False):
torch.cuda.empty_cache()
input_t = torch.from_numpy(input.astype(np.float32)).to(self.device)
target_t = torch.from_numpy(target.astype(np.float32)).to(self.device)
if mask is None:
mask = torch.ones_like(target_t).to(self.device)
else:
mask = torch.from_numpy(mask.astype(np.float32)).to(self.device)
mask = (mask == 128).float()
if calc_mad:
mad = self.BatchMAD(input_t, target_t, mask)
else:
mad = None
sad = self.BatchSAD(input_t, target_t, mask)
mse = self.BatchMSE(input_t, target_t, mask)
grad = self.BatchGradient(input_t, target_t, mask)
conn = self.BatchConnectivity(input_t, target_t, mask)
return sad, mad, mse, grad, conn
def run_quick(self, input, target, mask=None):
torch.cuda.empty_cache()
input_t = torch.from_numpy(input.astype(np.float32)).to(self.device)
target_t = torch.from_numpy(target.astype(np.float32)).to(self.device)
if mask is None:
mask = torch.ones_like(target_t).to(self.device)
else:
mask = torch.from_numpy(mask.astype(np.float32)).to(self.device)
mask = (mask == 128).float()
mad = self.BatchMAD(input_t, target_t, mask)
#sad = self.BatchSAD(input_t, target_t, mask)
mse = self.BatchMSE(input_t, target_t, mask)
#grad = self.BatchGradient(input_t, target_t, mask)
#conn = self.BatchConnectivity(input_t, target_t, mask)
return mad, mse
def run_metric(self, metric, input, target, mask=None):
torch.cuda.empty_cache()
input_t = torch.from_numpy(input.astype(np.float32)).to(self.device)
target_t = torch.from_numpy(target.astype(np.float32)).to(self.device)
if mask is None:
mask = torch.ones_like(target_t).to(self.device)
else:
mask = torch.from_numpy(mask.astype(np.float32)).to(self.device)
mask = (mask == 128).float()
if metric == 'sad':
ret = self.BatchSAD(input_t, target_t, mask)
elif metric == 'mse':
ret = self.BatchMSE(input_t, target_t, mask)
elif metric == 'grad':
ret = self.BatchGradient(input_t, target_t, mask)
elif metric == 'conn':
ret = self.BatchConnectivity(input_t, target_t, mask)
else:
raise NotImplementedError
return ret
def BatchSAD(self, pred, target, mask):
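# sum of absolute differences over the masked region; alpha values are scaled from
# [0, 255] to [0, 1] and the sum is divided by 1000, the usual scale for reporting
# SAD in matting benchmarks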
B = target.size(0)
error_map = (pred - target).abs() / 255.
batch_loss = (error_map * mask).view(B, -1).sum(dim=-1)
batch_loss = batch_loss / 1000.
return batch_loss.data.cpu().numpy()
def BatchMAD(self, pred, target, mask):
B = target.size(0)
error_map = (pred - target).abs() / 255.
batch_loss = (error_map * mask).view(B, -1).sum(dim=-1)
batch_loss = batch_loss / (mask.view(B, -1).sum(dim=-1) + 1.)
return batch_loss.data.cpu().numpy()
def BatchMSE(self, pred, target, mask):
B = target.size(0)
error_map = (pred-target) / 255.
batch_loss = (error_map.pow(2) * mask).view(B, -1).sum(dim=-1)
batch_loss = batch_loss / (mask.view(B, -1).sum(dim=-1) + 1.)
return batch_loss.data.cpu().numpy()
def BatchGradient(self, pred, target, mask):
B = target.size(0)
pred = pred / 255.
target = target / 255.
pred_x_t = self.fx(pred).squeeze(1)
pred_y_t = self.fy(pred).squeeze(1)
target_x_t = self.fx(target).squeeze(1)
target_y_t = self.fy(target).squeeze(1)
pred_amp = (pred_x_t.pow(2) + pred_y_t.pow(2)).sqrt()
target_amp = (target_x_t.pow(2) + target_y_t.pow(2)).sqrt()
error_map = (pred_amp - target_amp).pow(2)
batch_loss = (error_map * mask).view(B, -1).sum(dim=-1) / (mask.view(B,-1).sum(dim=-1) + 1.)
return batch_loss.data.cpu().numpy()
def BatchConnectivity(self, pred, target, mask):
_, h, w = pred.shape
step = self.conn_step
theta = self.conn_theta
pred = pred / 255.
target = target / 255.
B, dimy, dimx = pred.shape
thresh_steps = torch.arange(0, 1+step, step).to(self.device)
l_map = torch.ones_like(pred).to(self.device)*(-1)
pool = Pool(B)
for i in range(1, len(thresh_steps)):
pred_alpha_thresh = pred>=thresh_steps[i]
target_alpha_thresh = target>=thresh_steps[i]
mask_i = pred_alpha_thresh * target_alpha_thresh
omegas = []
items = [mask_ij.data.cpu().numpy() for mask_ij in mask_i]
for omega in pool.imap(findMaxConnectedRegion, items):
omegas.append(omega)
omegas = torch.from_numpy(np.array(omegas)).to(self.device)
flag = (l_map==-1) * (omegas==0)
l_map[flag==1] = thresh_steps[i-1]
l_map[l_map==-1] = 1
pred_d = pred - l_map
target_d = target - l_map
pred_phi = 1 - pred_d*(pred_d>=theta).float()
target_phi = 1 - target_d*(target_d>=theta).float()
batch_loss = ((pred_phi-target_phi).abs()*mask).view(B, -1).sum(dim=-1) / (mask.view(B,-1).sum(dim=-1) + 1.)
pool.close()
return batch_loss.data.cpu().numpy()
def GaussianGradient(self, mat):
gx = np.zeros_like(mat)
gy = np.zeros_like(mat)
for i in range(mat.shape[0]):
gx[i, ...] = ndimage.filters.convolve(mat[i], self.hx, mode='nearest')
gy[i, ...] = ndimage.filters.convolve(mat[i], self.hy, mode='nearest')
return gx, gy
def generate_trimap(alpha, k_size=3, iterations=5):
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
fg = np.array(np.equal(alpha, 255).astype(np.float32))
unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
unknown = cv2.dilate(unknown, kernel, iterations=iterations)
trimap = fg * 255 + (unknown - fg) * 128
return trimap.astype(np.uint8)
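# Example (illustrative): build a trimap from a ground-truth alpha matte.
# The paths are placeholders, not files from this repository.
#   alpha = cv2.imread('path/to/alpha.png', 0)
#   trimap = generate_trimap(alpha, k_size=3, iterations=5)
#   cv2.imwrite('path/to/trimap.png', trimap)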
# ------------------------------------------------------------------------
# Modified from Grounded-SAM (https://github.com/IDEA-Research/Grounded-Segment-Anything)
# ------------------------------------------------------------------------
import os
import random
import cv2
from scipy import ndimage
import gradio as gr
import argparse
import numpy as np
import torch
from torch.nn import functional as F
import torchvision
import networks
import utils
import time
# Grounding DINO
import sys
sys.path.insert(0, './GroundingDINO')
from groundingdino.util.inference import Model
# SAM
sys.path.insert(0, './segment-anything')
from segment_anything.utils.transforms import ResizeLongestSide
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# SD
from diffusers import StableDiffusionPipeline
transform = ResizeLongestSide(1024)
# Green Screen
PALETTE_back = (51, 255, 146)
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swint_ogc.pth"
mam_checkpoint="checkpoints/mam_vith.pth"
output_dir="outputs"
device="cuda"
background_list = os.listdir('assets/backgrounds')
# initialize MAM
mam_model = networks.get_generator_m2m(seg='sam_vit_h', m2m='sam_decoder_deep')
mam_model.to(device)
checkpoint = torch.load(mam_checkpoint, map_location=device)
mam_model.m2m.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
mam_model = mam_model.eval()
# initialize GroundingDINO
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=device)
# initialize StableDiffusionPipeline
generator = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
generator.to(device)
def run_grounded_sam(input_image, text_prompt, task_type, background_prompt, background_type, box_threshold, text_threshold, iou_threshold, scribble_mode, guidance_mode):
global grounding_dino_model, mam_model, generator
start_time = time.time()
# make dir
os.makedirs(output_dir, exist_ok=True)
# load image
image_ori = input_image["image"]
scribble = input_image["mask"]
original_size = image_ori.shape[:2]
if task_type == 'text':
if text_prompt is None:
print('Please input non-empty text prompt')
with torch.no_grad():
detections, phrases = grounding_dino_model.predict_with_caption(
image=cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR),
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
if len(detections.xyxy) > 1:
nms_idx = torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
iou_threshold,
).numpy().tolist()
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
bbox = detections.xyxy[np.argmax(detections.confidence)]
bbox = transform.apply_boxes(bbox, original_size)
bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)
image = transform.apply_image(image_ori)
image = torch.as_tensor(image).to(device)
image = image.permute(2, 0, 1).contiguous()
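# normalize with SAM's default pixel mean/std (ImageNet statistics scaled to the [0, 255] range)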
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3,1,1).to(device)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3,1,1).to(device)
image = (image - pixel_mean) / pixel_std
h, w = image.shape[-2:]
pad_size = image.shape[-2:]
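# pad on the right/bottom to the fixed 1024x1024 input resolution expected by the SAM image encoder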
padh = 1024 - h
padw = 1024 - w
image = F.pad(image, (0, padw, 0, padh))
if task_type == 'scribble_point':
scribble = scribble.transpose(2, 1, 0)[0]
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
centers = np.array(centers)
### (x,y)
centers = transform.apply_coords(centers, original_size)
point_coords = torch.from_numpy(centers).to(device)
point_coords = point_coords.unsqueeze(0).to(device)
point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
if scribble_mode == 'split':
point_coords = point_coords.permute(1, 0, 2)
point_labels = point_labels.permute(1, 0)
sample = {'image': image.unsqueeze(0), 'point': point_coords, 'label': point_labels, 'ori_shape': original_size, 'pad_shape': pad_size}
elif task_type == 'scribble_box':
scribble = scribble.transpose(2, 1, 0)[0]
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
centers = np.array(centers)
### (x1, y1, x2, y2)
x_min = centers[:, 0].min()
x_max = centers[:, 0].max()
y_min = centers[:, 1].min()
y_max = centers[:, 1].max()
bbox = np.array([x_min, y_min, x_max, y_max])
bbox = transform.apply_boxes(bbox, original_size)
bbox = torch.as_tensor(bbox, dtype=torch.float).to(device)
sample = {'image': image.unsqueeze(0), 'bbox': bbox.unsqueeze(0), 'ori_shape': original_size, 'pad_shape': pad_size}
elif task_type == 'text':
sample = {'image': image.unsqueeze(0), 'bbox': bbox.unsqueeze(0), 'ori_shape': original_size, 'pad_shape': pad_size}
else:
print("task_type:{} error!".format(task_type))
with torch.no_grad():
feas, pred, post_mask = mam_model.forward_inference(sample)
alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
alpha_pred_os8 = alpha_pred_os8[..., : sample['pad_shape'][0], : sample['pad_shape'][1]]
alpha_pred_os4 = alpha_pred_os4[..., : sample['pad_shape'][0], : sample['pad_shape'][1]]
alpha_pred_os1 = alpha_pred_os1[..., : sample['pad_shape'][0], : sample['pad_shape'][1]]
alpha_pred_os8 = F.interpolate(alpha_pred_os8, sample['ori_shape'], mode="bilinear", align_corners=False)
alpha_pred_os4 = F.interpolate(alpha_pred_os4, sample['ori_shape'], mode="bilinear", align_corners=False)
alpha_pred_os1 = F.interpolate(alpha_pred_os1, sample['ori_shape'], mode="bilinear", align_corners=False)
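# the decoder predicts alpha at output strides 8, 4 and 1; the coarse OS8 estimate is
# progressively refined by the OS4 and OS1 predictions inside the unknown regions
# selected by the weight maps below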
if guidance_mode == 'mask':
weight_os8 = utils.get_unknown_tensor_from_mask_oneside(post_mask, rand_width=10, train_mode=False)
post_mask[weight_os8>0] = alpha_pred_os8[weight_os8>0]
alpha_pred = post_mask.clone().detach()
else:
weight_os8 = utils.get_unknown_box_from_mask(post_mask)
alpha_pred_os8[weight_os8>0] = post_mask[weight_os8>0]
alpha_pred = alpha_pred_os8.clone().detach()
weight_os4 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=20, train_mode=False)
alpha_pred[weight_os4>0] = alpha_pred_os4[weight_os4>0]
weight_os1 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=10, train_mode=False)
alpha_pred[weight_os1>0] = alpha_pred_os1[weight_os1>0]
alpha_pred = alpha_pred[0][0].cpu().numpy()
#### draw
### alpha matte
alpha_rgb = cv2.cvtColor(np.uint8(alpha_pred*255), cv2.COLOR_GRAY2RGB)
### com img with background
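# standard alpha compositing: composite = alpha * foreground + (1 - alpha) * background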
if background_type == 'real_world_sample':
background_img_file = os.path.join('assets/backgrounds', random.choice(background_list))
background_img = cv2.imread(background_img_file)
background_img = cv2.cvtColor(background_img, cv2.COLOR_BGR2RGB)
background_img = cv2.resize(background_img, (image_ori.shape[1], image_ori.shape[0]))
com_img = alpha_pred[..., None] * image_ori + (1 - alpha_pred[..., None]) * np.uint8(background_img)
com_img = np.uint8(com_img)
else:
if background_prompt is None:
print('Please input non-empty background prompt')
else:
background_img = generator(background_prompt).images[0]
background_img = np.array(background_img)
background_img = cv2.resize(background_img, (image_ori.shape[1], image_ori.shape[0]))
com_img = alpha_pred[..., None] * image_ori + (1 - alpha_pred[..., None]) * np.uint8(background_img)
com_img = np.uint8(com_img)
### com img with green screen
green_img = alpha_pred[..., None] * image_ori + (1 - alpha_pred[..., None]) * np.array([PALETTE_back], dtype='uint8')
green_img = np.uint8(green_img)
end_time = time.time()
execution_time = end_time - start_time
print(f"推理时间:{execution_time}")
return [(com_img, 'composite with background'), (green_img, 'green screen'), (alpha_rgb, 'alpha matte')]
if __name__ == "__main__":
parser = argparse.ArgumentParser("MAM demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
parser.add_argument('--port', type=int, default=7589, help='port to run the server')
parser.add_argument('--no-gradio-queue', action="store_true", help='disable the gradio queue')
args = parser.parse_args()
print(args)
block = gr.Blocks()
if not args.no_gradio_queue:
block = block.queue()
with block:
gr.Markdown(
"""
# Matting Anything Demo
Welcome to the Matting Anything demo. Upload your image to get started. <br/> You may select different prompt types to get the alpha matte of the target instance, and select different backgrounds for image composition.
## Usage
You may check the <a href='https://www.youtube.com/watch?v=XY2Q0HATGOk'>video</a> to see how to play with the demo, or check the details below.
<details>
You may upload an image to start. We support 3 prompt types for getting the alpha matte of the target instance:
**scribble_point**: Click a point on the target instance.
**scribble_box**: Click two points, the top-left and the bottom-right, to define a bounding box of the target instance.
**text**: Enter a text prompt in the `Text prompt` box to identify the target instance.
We also support 2 background types for image composition with the alpha matte output:
**real_world_sample**: Randomly select a real-world image from `assets/backgrounds` for composition.
**generated_by_text**: Enter a background text prompt in the `Background prompt` box to generate a background image with the Stable Diffusion model.
</details>
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(source='upload', type="numpy", value="assets/demo.jpg", tool="sketch")
task_type = gr.Dropdown(["scribble_point", "scribble_box", "text"], value="text", label="Prompt type")
text_prompt = gr.Textbox(label="Text prompt", placeholder="the girl in the middle")
background_type = gr.Dropdown(["generated_by_text", "real_world_sample"], value="generated_by_text", label="Background type")
background_prompt = gr.Textbox(label="Background prompt", placeholder="downtown area in New York")
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
box_threshold = gr.Slider(
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05
)
text_threshold = gr.Slider(
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05
)
iou_threshold = gr.Slider(
label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05
)
scribble_mode = gr.Dropdown(
["merge", "split"], value="split", label="scribble_mode"
)
guidance_mode = gr.Dropdown(
["mask", "alpha"], value="alpha", label="guidance_mode", info="mask guidance is for complex scenes with multiple instances, alpha guidance is for simple scene with single instance"
)
with gr.Column():
gallery = gr.Gallery(
label="Generated images", show_label=True, elem_id="gallery"
).style(preview=True, grid=3, object_fit="scale-down")
run_button.click(fn=run_grounded_sam, inputs=[
input_image, text_prompt, task_type, background_prompt, background_type, box_threshold, text_threshold, iou_threshold, scribble_mode, guidance_mode], outputs=gallery)
block.queue(concurrency_count=100)
block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
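# Example (illustrative): launch the demo with a public share link; the script file
# name is an assumption.
#   python gradio_app.py --port 7589 --share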
import os
import cv2
import toml
import argparse
import numpy as np
import json
import torch
from torch.nn import functional as F
import torchvision
import utils
from utils import CONFIG
import networks
from tqdm import tqdm
from fvcore.nn import FlopCountAnalysis
import sys
sys.path.insert(0, './segment-anything')
sys.path.insert(0, './GroundingDINO')
from segment_anything.utils.transforms import ResizeLongestSide
from groundingdino.util.inference import Model
transform = ResizeLongestSide(1024)
def single_ms_inference(model, image_dict, args):
with torch.no_grad():
feas, pred, post_mask = model.forward_inference(image_dict)
if args.sam:
post_mask = post_mask[0].cpu().numpy() * 255
return post_mask.transpose(1, 2, 0).astype('uint8')
alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
alpha_pred_os8 = alpha_pred_os8[..., : image_dict['pad_shape'][0], : image_dict['pad_shape'][1]]
alpha_pred_os4 = alpha_pred_os4[..., : image_dict['pad_shape'][0], : image_dict['pad_shape'][1]]
alpha_pred_os1 = alpha_pred_os1[..., : image_dict['pad_shape'][0], : image_dict['pad_shape'][1]]
alpha_pred_os8 = F.interpolate(alpha_pred_os8, image_dict['ori_shape'], mode="bilinear", align_corners=False)
alpha_pred_os4 = F.interpolate(alpha_pred_os4, image_dict['ori_shape'], mode="bilinear", align_corners=False)
alpha_pred_os1 = F.interpolate(alpha_pred_os1, image_dict['ori_shape'], mode="bilinear", align_corners=False)
if args.maskguide:
if args.twoside:
weight_os8 = utils.get_unknown_tensor_from_mask(post_mask, rand_width=args.os8_width, train_mode=False)
else:
weight_os8 = utils.get_unknown_tensor_from_mask_oneside(post_mask, rand_width=args.os8_width, train_mode=False)
post_mask[weight_os8>0] = alpha_pred_os8[weight_os8>0]
alpha_pred = post_mask.clone().detach()
else:
if args.postprocess:
weight_os8 = utils.get_unknown_box_from_mask(post_mask)
alpha_pred_os8[weight_os8>0] = post_mask[weight_os8>0]
alpha_pred = alpha_pred_os8.clone().detach()
if args.twoside:
weight_os4 = utils.get_unknown_tensor_from_pred(alpha_pred, rand_width=args.os4_width, train_mode=False)
else:
weight_os4 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=args.os4_width, train_mode=False)
alpha_pred[weight_os4>0] = alpha_pred_os4[weight_os4>0]
if args.twoside:
weight_os1 = utils.get_unknown_tensor_from_pred(alpha_pred, rand_width=args.os1_width, train_mode=False)
else:
weight_os1 = utils.get_unknown_tensor_from_pred_oneside(alpha_pred, rand_width=args.os1_width, train_mode=False)
alpha_pred[weight_os1>0] = alpha_pred_os1[weight_os1>0]
alpha_pred = alpha_pred[0].cpu().numpy() * 255
return alpha_pred.transpose(1, 2, 0).astype('uint8')
def generator_tensor_dict(image_path, alpha_path, args):
# read images
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_size = image.shape[:2]
alpha_single = cv2.imread(alpha_path, 0)
alpha_single[alpha_single>127] = 255
alpha_single[alpha_single<=127] = 0
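# derive a box prompt from the binarized ground-truth alpha: the tight bounding box of its foreground pixels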
fg_set = np.where(alpha_single != 0)
x_min = np.min(fg_set[1])
x_max = np.max(fg_set[1])
y_min = np.min(fg_set[0])
y_max = np.max(fg_set[0])
bbox = np.array([x_min, y_min, x_max, y_max])
image = transform.apply_image(image)
image = torch.as_tensor(image).cuda()
image = image.permute(2, 0, 1).contiguous()
bbox = transform.apply_boxes(bbox, original_size)
input_point = np.array([[(bbox[0][0] + bbox[0][2])/2, (bbox[0][1] + bbox[0][3])/2]])
input_label = np.array([1])
input_point = torch.as_tensor(input_point, dtype=torch.float).cuda()
input_label = torch.as_tensor(input_label, dtype=torch.float).cuda()
bbox = torch.as_tensor(bbox, dtype=torch.float).cuda()
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3,1,1).cuda()
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3,1,1).cuda()
image = (image - pixel_mean) / pixel_std
h, w = image.shape[-2:]
pad_size = image.shape[-2:]
padh = 1024 - h
padw = 1024 - w
image = F.pad(image, (0, padw, 0, padh))
if args.prompt == 'box':
sample = {'image': image[None, ...], 'bbox': bbox[None, ...], 'ori_shape': original_size, 'pad_shape': pad_size}
elif args.prompt == 'point':
sample = {'image': image[None, ...], 'point': input_point[None, ...], 'label': input_label[None, ...], 'ori_shape': original_size, 'pad_shape': pad_size}
return sample
def generator_tensor_dict_from_text(image_path, text, dino_model, args):
# read images
image = cv2.imread(image_path)
detections, phrases = dino_model.predict_with_caption(
image=image,
caption=text,
box_threshold=0.25,
text_threshold=0.5
)
if len(detections.xyxy) > 1:
nms_idx = torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
0.8,
).numpy().tolist()
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
bbox = detections.xyxy[np.argmax(detections.confidence)]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_size = image.shape[:2]
image = transform.apply_image(image)
image = torch.as_tensor(image).cuda()
image = image.permute(2, 0, 1).contiguous()
bbox = transform.apply_boxes(bbox, original_size)
bbox = torch.as_tensor(bbox, dtype=torch.float).cuda()
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3,1,1).cuda()
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3,1,1).cuda()
image = (image - pixel_mean) / pixel_std
h, w = image.shape[-2:]
pad_size = image.shape[-2:]
padh = 1024 - h
padw = 1024 - w
image = F.pad(image, (0, padw, 0, padh))
sample = {'image': image[None, ...], 'bbox': bbox[None, ...], 'ori_shape': original_size, 'pad_shape': pad_size}
return sample
if __name__ == '__main__':
print('Torch Version: ', torch.__version__)
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/MAM-ViTB-8gpu.toml')
parser.add_argument('--benchmark', type=str, default='him2k', choices=['him2k', 'him2k_comp', 'rwp636', 'ppm100', 'am2k', 'pm10k', 'rw100'])
parser.add_argument('--checkpoint', type=str, default='checkpoints/mam_sam_vitb.pth',
help="path of checkpoint")
parser.add_argument('--image-ext', type=str, default='.jpg', help="input image ext")
parser.add_argument('--mask-ext', type=str, default='.png', help="input mask ext")
parser.add_argument('--output', type=str, default='outputs/', help="output dir")
parser.add_argument('--os8_width', type=int, default=10, help="guidance threshold")
parser.add_argument('--os4_width', type=int, default=20, help="guidance threshold")
parser.add_argument('--os1_width', type=int, default=10, help="guidance threshold")
parser.add_argument('--twoside', action='store_true', default=False, help='post process with twoside of the guidance')
parser.add_argument('--sam', action='store_true', default=False, help='return mask')
parser.add_argument('--maskguide', action='store_true', default=False, help='mask guidance')
parser.add_argument('--postprocess', action='store_true', default=False, help='postprocess to remove bg')
parser.add_argument('--prompt', type=str, default='box', choices=['box', 'point', 'text'])
# Parse configuration
args = parser.parse_args()
with open(args.config) as f:
utils.load_config(toml.load(f))
# Check if toml config file is loaded
if CONFIG.is_default:
raise ValueError("No .toml config loaded.")
args.output = os.path.join(args.output)
utils.make_dir(args.output)
# build model
model = networks.get_generator_m2m(seg=CONFIG.model.arch.seg, m2m=CONFIG.model.arch.m2m)
model.cuda()
# load checkpoint
checkpoint = torch.load(args.checkpoint)
model.m2m.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
# inference
model = model.eval()
n_parameters = sum(p.numel() for p in model.m2m.parameters() if p.requires_grad)
print('number of params:', n_parameters)
if args.prompt == 'text':
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swint_ogc.pth"
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
if args.benchmark == 'him2k':
image_dir = CONFIG.benchmark.him2k_img
alpha_dir = CONFIG.benchmark.him2k_alpha
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, os.path.splitext(image_name)[0])
output_path = os.path.join(args.output, os.path.splitext(image_name)[0])
utils.make_dir(output_path)
for alpha_single_dir in sorted(os.listdir(alpha_path)):
alpha_single_path = os.path.join(alpha_path, alpha_single_dir)
image_dict = generator_tensor_dict(image_path, alpha_single_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(output_path, alpha_single_dir), alpha_pred)
elif args.benchmark == 'him2k_comp':
image_dir = CONFIG.benchmark.him2k_comp_img
alpha_dir = CONFIG.benchmark.him2k_comp_alpha
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, os.path.splitext(image_name)[0])
output_path = os.path.join(args.output, os.path.splitext(image_name)[0])
utils.make_dir(output_path)
for alpha_single_dir in sorted(os.listdir(alpha_path)):
alpha_single_path = os.path.join(alpha_path, alpha_single_dir)
image_dict = generator_tensor_dict(image_path, alpha_single_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(output_path, alpha_single_dir), alpha_pred)
elif args.benchmark == 'rwp636':
image_dir = CONFIG.benchmark.rwp636_img
alpha_dir = CONFIG.benchmark.rwp636_alpha
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, os.path.splitext(image_name)[0]+'.png')
image_dict = generator_tensor_dict(image_path, alpha_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(args.output, os.path.splitext(image_name)[0]+'.png'), alpha_pred)
elif args.benchmark == 'ppm100':
image_dir = CONFIG.benchmark.ppm100_img
alpha_dir = CONFIG.benchmark.ppm100_alpha
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, image_name)
image_dict = generator_tensor_dict(image_path, alpha_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(args.output, os.path.splitext(image_name)[0]+'.png'), alpha_pred)
elif args.benchmark == 'am2k':
image_dir = CONFIG.benchmark.am2k_img
alpha_dir = CONFIG.benchmark.am2k_alpha
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, os.path.splitext(image_name)[0]+'.png')
image_dict = generator_tensor_dict(image_path, alpha_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(args.output, os.path.splitext(image_name)[0]+'.png'), alpha_pred)
elif args.benchmark == 'pm10k':
image_dir = CONFIG.benchmark.pm10k_img
alpha_dir = CONFIG.benchmark.pm10k_alpha
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, os.path.splitext(image_name)[0]+'.png')
image_dict = generator_tensor_dict(image_path, alpha_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(args.output, os.path.splitext(image_name)[0]+'.png'), alpha_pred)
elif args.benchmark == 'rw100':
image_dir = CONFIG.benchmark.rw100_img
text_dir = CONFIG.benchmark.rw100_text
index_dir = CONFIG.benchmark.rw100_index
alpha_dir = CONFIG.benchmark.rw100_alpha
if args.prompt == 'text':
index_data = json.load(open(index_dir, 'r'))
text_data = json.load(open(text_dir, 'r'))
for i, image_name in enumerate(tqdm(os.listdir(image_dir))):
if args.prompt == 'text':
image_path = os.path.join(image_dir, image_name)
text_label = text_data[os.path.splitext(image_name)[0]]
index_label = index_data[text_label['image_name']]
text = text_label['expressions'][index_label]
image_dict = generator_tensor_dict_from_text(image_path, text, grounding_dino_model, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(args.output, os.path.splitext(image_name)[0]+'.png'), alpha_pred)
else:
image_path = os.path.join(image_dir, image_name)
alpha_path = os.path.join(alpha_dir, os.path.splitext(image_name)[0]+'.png')
image_dict = generator_tensor_dict(image_path, alpha_path, args)
alpha_pred = single_ms_inference(model, image_dict, args)
cv2.imwrite(os.path.join(args.output, os.path.splitext(image_name)[0]+'.png'), alpha_pred)
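# Example (illustrative): box-prompt inference on the PPM-100 benchmark; the script
# file name is an assumption, the checkpoint and config paths follow the defaults above.
#   python inference_benchmarks.py --benchmark ppm100 --config config/MAM-ViTB-8gpu.toml \
#       --checkpoint checkpoints/mam_sam_vitb.pth --prompt box --output outputs/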
# ------------------------------------------------------------------------
# Modified from MGMatting (https://github.com/yucornetto/MGMatting)
# ------------------------------------------------------------------------
import os
import toml
import argparse
from pprint import pprint
import torch
from torch.utils.data import DataLoader
import utils
from utils import CONFIG
from trainer import Trainer
from dataloader.image_file import ImageFileTrain
from dataloader.data_generator import DataGenerator
from dataloader.prefetcher import Prefetcher
import wandb
import warnings
warnings.filterwarnings("ignore")
def main():
# Train or Test
if CONFIG.phase.lower() == "train":
# set distributed training
if CONFIG.dist:
CONFIG.gpu = CONFIG.local_rank
torch.cuda.set_device(CONFIG.gpu)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
CONFIG.world_size = torch.distributed.get_world_size()
# Create directories if not exist.
if CONFIG.local_rank == 0:
utils.make_dir(CONFIG.log.logging_path)
utils.make_dir(CONFIG.log.tensorboard_path)
utils.make_dir(CONFIG.log.checkpoint_path)
if CONFIG.wandb:
wandb.init(project="mam", config=CONFIG, name=CONFIG.version)
# Create a logger
logger, tb_logger = utils.get_logger(CONFIG.log.logging_path,
CONFIG.log.tensorboard_path,
logging_level=CONFIG.log.logging_level)
train_dataset = DataGenerator(phase='train')
if CONFIG.dist:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_dataloader = DataLoader(train_dataset,
batch_size=CONFIG.model.batch_size,
shuffle=(train_sampler is None),
num_workers=CONFIG.data.workers,
pin_memory=True,
sampler=train_sampler,
drop_last=True)
train_dataloader = Prefetcher(train_dataloader)
trainer = Trainer(train_dataloader=train_dataloader,
test_dataloader=None,
logger=logger,
tb_logger=tb_logger)
trainer.train()
else:
raise NotImplementedError("Unknown Phase: {}".format(CONFIG.phase))
if __name__ == '__main__':
print('Torch Version: ', torch.__version__)
parser = argparse.ArgumentParser()
parser.add_argument('--phase', type=str, default='train')
parser.add_argument('--config', type=str, default='config/gca-dist.toml')
parser.add_argument('--local_rank', type=int, default=0)
# Parse configuration
args = parser.parse_args()
with open(args.config) as f:
utils.load_config(toml.load(f))
# Check if toml config file is loaded
if CONFIG.is_default:
raise ValueError("No .toml config loaded.")
CONFIG.phase = args.phase
CONFIG.log.logging_path = os.path.join(CONFIG.log.logging_path, CONFIG.version)
CONFIG.log.tensorboard_path = os.path.join(CONFIG.log.tensorboard_path, CONFIG.version)
CONFIG.log.checkpoint_path = os.path.join(CONFIG.log.checkpoint_path, CONFIG.version)
if args.local_rank == 0:
print('CONFIG: ')
pprint(CONFIG)
CONFIG.local_rank = args.local_rank
# Train
main()
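# Example (illustrative): 8-GPU distributed training launch; the script file name,
# GPU count and config path are assumptions.
#   python -m torch.distributed.launch --nproc_per_node=8 main.py \
#       --config config/MAM-ViTB-8gpu.toml --phase train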
# Unique model identifier
modelCode=1119
# Model name
modelName=matting-anything_pytorch
# Model description
modelDescription=An efficient and versatile image matting framework.
# Application scenarios
appScenario=AIGC,retail,manufacturing,e-commerce,healthcare,education
# Framework type
frameType=pytorch
from .generator_m2m import *
# ------------------------------------------------------------------------
# Modified from MGMatting (https://github.com/yucornetto/MGMatting)
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import CONFIG
from networks import m2ms, ops
import sys
sys.path.insert(0, './segment-anything')
from segment_anything import sam_model_registry
class sam_m2m(nn.Module):
def __init__(self, seg, m2m):
super(sam_m2m, self).__init__()
if m2m not in m2ms.__all__:
raise NotImplementedError("Unknown M2M {}".format(m2m))
self.m2m = m2ms.__dict__[m2m](nc=256)
if seg == 'sam_vit_b':
self.seg_model = sam_model_registry['vit_b'](checkpoint='segment-anything/checkpoints/sam_vit_b_01ec64.pth')
elif seg == 'sam_vit_l':
self.seg_model = sam_model_registry['vit_l'](checkpoint='segment-anything/checkpoints/sam_vit_l_0b3195.pth')
elif seg == 'sam_vit_h':
self.seg_model = sam_model_registry['vit_h'](checkpoint='segment-anything/checkpoints/sam_vit_h_4b8939.pth')
self.seg_model.eval()
def forward(self, image, guidance):
self.seg_model.eval()
with torch.no_grad():
feas, masks = self.seg_model.forward_m2m(image, guidance, multimask_output=True)
pred = self.m2m(feas, image, masks)
return pred
def forward_inference(self, image_dict):
self.seg_model.eval()
with torch.no_grad():
feas, masks, post_masks = self.seg_model.forward_m2m_inference(image_dict, multimask_output=True)
pred = self.m2m(feas, image_dict["image"], masks)
return feas, pred, post_masks
def get_generator_m2m(seg, m2m):
if 'sam' in seg:
generator = sam_m2m(seg=seg, m2m=m2m)
return generator
from .conv_sam import SAM_Decoder_Deep
__all__ = ['sam_decoder_deep']
def sam_decoder_deep(nc, **kwargs):
model = SAM_Decoder_Deep(nc, [2, 3, 3, 2], **kwargs)
return model
# ------------------------------------------------------------------------
# Modified from MGMatting (https://github.com/yucornetto/MGMatting)
# ------------------------------------------------------------------------
import logging
import torch.nn as nn
import torch
import torch.nn.functional as F
from networks import ops
def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""5x5 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
padding=2, groups=groups, bias=False, dilation=dilation)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, upsample=None, norm_layer=None, large_kernel=False):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.stride = stride
conv = conv5x5 if large_kernel else conv3x3
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
if self.stride > 1:
self.conv1 = ops.SpectralNorm(nn.ConvTranspose2d(inplanes, inplanes, kernel_size=4, stride=2, padding=1, bias=False))
else:
self.conv1 = ops.SpectralNorm(conv(inplanes, inplanes))
self.bn1 = norm_layer(inplanes)
self.activation = nn.LeakyReLU(0.2, inplace=True)
self.conv2 = ops.SpectralNorm(conv(inplanes, planes))
self.bn2 = norm_layer(planes)
self.upsample = upsample
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
out = self.bn2(out)
if self.upsample is not None:
identity = self.upsample(x)
out += identity
out = self.activation(out)
return out
class SAM_Decoder_Deep(nn.Module):
def __init__(self, nc, layers, block=BasicBlock, norm_layer=None, large_kernel=False, late_downsample=False):
super(SAM_Decoder_Deep, self).__init__()
self.logger = logging.getLogger("Logger")
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.large_kernel = large_kernel
self.kernel_size = 5 if self.large_kernel else 3
#self.inplanes = 512 if layers[0] > 0 else 256
self.inplanes = 256
self.late_downsample = late_downsample
self.midplanes = 64 if late_downsample else 32
self.conv1 = ops.SpectralNorm(nn.ConvTranspose2d(self.midplanes, 32, kernel_size=4, stride=2, padding=1, bias=False))
self.bn1 = norm_layer(32)
self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
self.tanh = nn.Tanh()
#self.layer1 = self._make_layer(block, 256, layers[0], stride=2)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
self.layer4 = self._make_layer(block, self.midplanes, layers[3], stride=2)
self.refine_OS1 = nn.Sequential(
nn.Conv2d(32, 32, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size//2, bias=False),
norm_layer(32),
self.leaky_relu,
nn.Conv2d(32, 1, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size//2),)
self.refine_OS4 = nn.Sequential(
nn.Conv2d(64, 32, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size//2, bias=False),
norm_layer(32),
self.leaky_relu,
nn.Conv2d(32, 1, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size//2),)
self.refine_OS8 = nn.Sequential(
nn.Conv2d(128, 32, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size//2, bias=False),
norm_layer(32),
self.leaky_relu,
nn.Conv2d(32, 1, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size//2),)
for m in self.modules():
if isinstance(m, nn.Conv2d):
if hasattr(m, "weight_bar"):
nn.init.xavier_uniform_(m.weight_bar)
else:
nn.init.xavier_uniform_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
for m in self.modules():
if isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
self.logger.debug(self)
def _make_layer(self, block, planes, blocks, stride=1):
if blocks == 0:
return nn.Sequential(nn.Identity())
norm_layer = self._norm_layer
upsample = None
if stride != 1:
upsample = nn.Sequential(
nn.UpsamplingNearest2d(scale_factor=2),
ops.SpectralNorm(conv1x1(self.inplanes + 4, planes * block.expansion)),
norm_layer(planes * block.expansion),
)
elif self.inplanes != planes * block.expansion:
upsample = nn.Sequential(
ops.SpectralNorm(conv1x1(self.inplanes + 4, planes * block.expansion)),
norm_layer(planes * block.expansion),
)
layers = [block(self.inplanes + 4, planes, stride, upsample, norm_layer, self.large_kernel)]
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer, large_kernel=self.large_kernel))
return nn.Sequential(*layers)
def forward(self, x_os16, img, mask):
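# x_os16: SAM encoder features at output stride 16; the image and the coarse mask are
# resized and concatenated at every scale as the decoder upsamples to full resolution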
ret = {}
mask_os16 = F.interpolate(mask, x_os16.shape[2:], mode='bilinear', align_corners=False)
img_os16 = F.interpolate(img, x_os16.shape[2:], mode='bilinear', align_corners=False)
x = self.layer2(torch.cat((x_os16, img_os16, mask_os16), dim=1)) # N x 128 x 128 x 128
x_os8 = self.refine_OS8(x)
mask_os8 = F.interpolate(mask, x.shape[2:], mode='bilinear', align_corners=False)
img_os8 = F.interpolate(img, x.shape[2:], mode='bilinear', align_corners=False)
x = self.layer3(torch.cat((x, img_os8, mask_os8), dim=1)) # N x 64 x 256 x 256
x_os4 = self.refine_OS4(x)
mask_os4 = F.interpolate(mask, x.shape[2:], mode='bilinear', align_corners=False)
img_os4 = F.interpolate(img, x.shape[2:], mode='bilinear', align_corners=False)
x = self.layer4(torch.cat((x, img_os4, mask_os4), dim=1)) # N x 32 x 512 x 512
x = self.conv1(x)
x = self.bn1(x)
x = self.leaky_relu(x) # N x 32 x 1024 x 1024
x_os1 = self.refine_OS1(x) # N
x_os4 = F.interpolate(x_os4, scale_factor=4.0, mode='bilinear', align_corners=False)
x_os8 = F.interpolate(x_os8, scale_factor=8.0, mode='bilinear', align_corners=False)
x_os1 = (torch.tanh(x_os1) + 1.0) / 2.0
x_os4 = (torch.tanh(x_os4) + 1.0) / 2.0
x_os8 = (torch.tanh(x_os8) + 1.0) / 2.0
mask_os1 = F.interpolate(mask, x_os1.shape[2:], mode='bilinear', align_corners=False)
ret['alpha_os1'] = x_os1
ret['alpha_os4'] = x_os4
ret['alpha_os8'] = x_os8
ret['mask'] = mask_os1
return ret
import torch
from torch import nn
from torch.nn import Parameter
from torch.autograd import Variable
from torch.nn import functional as F
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
"""
Based on https://github.com/heykeetae/Self-Attention-GAN/blob/master/spectral.py
and add _noupdate_u_v() for evaluation
"""
def __init__(self, module, name='weight', power_iterations=1):
super(SpectralNorm, self).__init__()
self.module = module
self.name = name
self.power_iterations = power_iterations
if not self._made_params():
self._make_params()
def _update_u_v(self):
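# one or more power-iteration steps estimate the leading singular value sigma of the
# weight matrix; the effective weight used in the forward pass is then W / sigma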
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _noupdate_u_v(self):
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _made_params(self):
try:
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
return True
except AttributeError:
return False
def _make_params(self):
w = getattr(self.module, self.name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
u.data = l2normalize(u.data)
v.data = l2normalize(v.data)
w_bar = Parameter(w.data)
del self.module._parameters[self.name]
self.module.register_parameter(self.name + "_u", u)
self.module.register_parameter(self.name + "_v", v)
self.module.register_parameter(self.name + "_bar", w_bar)
def forward(self, *args):
# if torch.is_grad_enabled() and self.module.training:
if self.module.training:
self._update_u_v()
else:
self._noupdate_u_v()
return self.module.forward(*args)
class ASPP(nn.Module):
'''
based on https://github.com/chenxi116/DeepLabv3.pytorch/blob/master/deeplab.py
'''
def __init__(self, in_channel, out_channel, conv=nn.Conv2d, norm=nn.BatchNorm2d):
super(ASPP, self).__init__()
mid_channel = 256
dilations = [1, 2, 4, 8]
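# ASPP: parallel 1x1 and dilated 3x3 convolutions plus a global-pooling branch,
# concatenated and fused by a final 1x1 convolution (DeepLabv3-style)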
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.relu = nn.ReLU(inplace=True)
self.aspp1 = conv(in_channel, mid_channel, kernel_size=1, stride=1, dilation=dilations[0], bias=False)
self.aspp2 = conv(in_channel, mid_channel, kernel_size=3, stride=1,
dilation=dilations[1], padding=dilations[1],
bias=False)
self.aspp3 = conv(in_channel, mid_channel, kernel_size=3, stride=1,
dilation=dilations[2], padding=dilations[2],
bias=False)
self.aspp4 = conv(in_channel, mid_channel, kernel_size=3, stride=1,
dilation=dilations[3], padding=dilations[3],
bias=False)
self.aspp5 = conv(in_channel, mid_channel, kernel_size=1, stride=1, bias=False)
self.aspp1_bn = norm(mid_channel)
self.aspp2_bn = norm(mid_channel)
self.aspp3_bn = norm(mid_channel)
self.aspp4_bn = norm(mid_channel)
self.aspp5_bn = norm(mid_channel)
self.conv2 = conv(mid_channel * 5, out_channel, kernel_size=1, stride=1,
bias=False)
self.bn2 = norm(out_channel)
def forward(self, x):
x1 = self.aspp1(x)
x1 = self.aspp1_bn(x1)
x1 = self.relu(x1)
x2 = self.aspp2(x)
x2 = self.aspp2_bn(x2)
x2 = self.relu(x2)
x3 = self.aspp3(x)
x3 = self.aspp3_bn(x3)
x3 = self.relu(x3)
x4 = self.aspp4(x)
x4 = self.aspp4_bn(x4)
x4 = self.relu(x4)
x5 = self.global_pooling(x)
x5 = self.aspp5(x5)
x5 = self.aspp5_bn(x5)
x5 = self.relu(x5)
x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='nearest')(x5)
x = torch.cat((x1, x2, x3, x4, x5), 1)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
[flake8]
ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9
per-file-ignores =
**/__init__.py:F401,F403,E402