Commit bea94578 authored by bailuo

init
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
"window_partition",
"window_unpartition",
"add_decomposed_rel_pos",
"get_abs_pos",
"PatchEmbed",
]
def window_partition(x, window_size):
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows, window_size, pad_hw, hw):
"""
    Reverse window partition back to original feature maps and remove padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
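# A minimal round-trip sketch (shapes here are illustrative assumptions, not taken from any
# training config): partition a feature map into 14x14 windows, then undo the partition.
#
#   x = torch.randn(2, 30, 30, 64)                        # [B, H, W, C]; 30 is not a multiple of 14
#   windows, (Hp, Wp) = window_partition(x, 14)           # -> [2 * 3 * 3, 14, 14, 64], padded to 42x42
#   x_back = window_unpartition(windows, 14, (Hp, Wp), (30, 30))
#   assert torch.equal(x_back, x)                         # padding is stripped, original map recovered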
def get_rel_pos(q_size, k_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
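# Shape sketch for the decomposed relative-position bias (illustrative sizes only, assuming a
# 14x14 query/key grid and per-head dim 64):
#
#   Rh = get_rel_pos(14, 14, rel_pos_h)                   # (14, 14, 64): one row per (q_h, k_h) offset
#   Rw = get_rel_pos(14, 14, rel_pos_w)                   # (14, 14, 64)
#   rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)       # (B, 14, 14, 14): bias along the height axis
#   rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)       # (B, 14, 14, 14): bias along the width axis
#   # rel_h is broadcast over k_w and rel_w over k_h, so attn gains a separable (k_h, k_w) bias.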
def get_abs_pos(abs_pos, has_cls_token, hw):
"""
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
dimension for the original embeddings.
Args:
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
hw (Tuple): size of input image tokens.
Returns:
Absolute positional embeddings after processing with shape (1, H, W, C)
"""
h, w = hw
if has_cls_token:
abs_pos = abs_pos[:, 1:]
xy_num = abs_pos.shape[1]
size = int(math.sqrt(xy_num))
assert size * size == xy_num
if size != h or size != w:
new_abs_pos = F.interpolate(
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
size=(h, w),
mode="bicubic",
align_corners=False,
)
return new_abs_pos.permute(0, 2, 3, 1)
else:
return abs_pos.reshape(1, h, w, -1)
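# Example of how the pretraining position grid is rescaled (numbers are assumptions): a ViT
# pretrained at 224x224 with 16x16 patches stores a 14x14 = 196 token grid (plus an optional
# cls token). For a 512x512 input the backbone needs a 32x32 grid, so the embeddings are
# interpolated:
#
#   abs_pos = torch.randn(1, 197, 768)                    # 196 patch tokens + 1 cls token
#   pos = get_abs_pos(abs_pos, has_cls_token=True, hw=(32, 32))
#   # pos.shape == (1, 32, 32, 768), ready to be added to the patch embeddings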
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
):
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x):
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
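# Usage sketch (illustrative input size): a 1024x1024 RGB image with 16x16 patches becomes a
# 64x64 grid of 768-dim tokens in channels-last layout.
#
#   patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
#   tokens = patch_embed(torch.randn(1, 3, 1024, 1024))   # -> (1, 64, 64, 768)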
import logging
import math
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
from torch.nn import functional as F
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models.layers import DropPath, Mlp, trunc_normal_
from .backbone import Backbone
from .utils import (
PatchEmbed,
add_decomposed_rel_pos,
get_abs_pos,
window_partition,
window_unpartition,
)
logger = logging.getLogger(__name__)
__all__ = ["ViT"]
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
use_rel_pos=False,
rel_pos_zero_init=True,
input_size=None,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple or None): Input resolution for calculating the relative positional
                parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
if not rel_pos_zero_init:
trunc_normal_(self.rel_pos_h, std=0.02)
trunc_normal_(self.rel_pos_w, std=0.02)
def forward(self, x):
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
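# Usage sketch (assumed sizes): attention over a 14x14 window of 768-dim tokens with relative
# position embeddings enabled; input and output keep the [B, H, W, C] layout.
#
#   attn = Attention(dim=768, num_heads=12, use_rel_pos=True, input_size=(14, 14))
#   y = attn(torch.randn(2, 14, 14, 768))                 # -> (2, 14, 14, 768)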
class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class ResBottleneckBlock(CNNBlockBase):
"""
The standard bottleneck residual block without the last activation layer.
It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
"""
def __init__(
self,
in_channels,
out_channels,
bottleneck_channels,
norm="LN",
act_layer=nn.GELU,
conv_kernels=3,
conv_paddings=1,
):
"""
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
bottleneck_channels (int): number of output channels for the 3x3
"bottleneck" conv layers.
norm (str or callable): normalization for all conv layers.
See :func:`layers.get_norm` for supported format.
act_layer (callable): activation for all conv layers.
"""
super().__init__(in_channels, out_channels, 1)
self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
self.norm1 = get_norm(norm, bottleneck_channels)
self.act1 = act_layer()
self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
conv_kernels,
padding=conv_paddings,
bias=False,
)
self.norm2 = get_norm(norm, bottleneck_channels)
self.act2 = act_layer()
self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
self.norm3 = get_norm(norm, out_channels)
for layer in [self.conv1, self.conv2, self.conv3]:
weight_init.c2_msra_fill(layer)
for layer in [self.norm1, self.norm2]:
layer.weight.data.fill_(1.0)
layer.bias.data.zero_()
# zero init last norm layer.
self.norm3.weight.data.zero_()
self.norm3.bias.data.zero_()
def forward(self, x):
out = x
for layer in self.children():
out = layer(out)
out = x + out
return out
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
use_cc_attn = False,
use_residual_block=False,
use_convnext_block=False,
input_size=None,
res_conv_kernel_size=3,
res_conv_padding=1,
):
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, window
                attention is not used.
            use_cc_attn (bool): If True, replace the attention layer with CrissCrossAttention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            use_convnext_block (bool): If True, apply a ConvNeXt block after the MLP block.
            input_size (tuple or None): Input resolution for calculating the relative positional
                parameter size.
            res_conv_kernel_size (int): kernel size of the middle conv in the residual block.
            res_conv_padding (int): padding of the middle conv in the residual block.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
self.window_size = window_size
self.use_residual_block = use_residual_block
if use_residual_block:
# Use a residual block with bottleneck channel as dim // 2
self.residual = ResBottleneckBlock(
in_channels=dim,
out_channels=dim,
bottleneck_channels=dim // 2,
norm="LN",
act_layer=act_layer,
conv_kernels=res_conv_kernel_size,
conv_paddings=res_conv_padding,
)
self.use_convnext_block = use_convnext_block
if use_convnext_block:
self.convnext = ConvNextBlock(dim = dim)
if use_cc_attn:
self.attn = CrissCrossAttention(dim)
def forward(self, x):
shortcut = x
x = self.norm1(x)
        if not self.training and self.window_size == 0:
            # Memory-saving path for global-attention blocks at inference: split the token
            # grid into 2x2 strided sub-grids, run attention on each sub-grid one sample at
            # a time, then scatter the results back to their original positions.
            B, H, W, C = x.shape
            fea = torch.zeros_like(x)
            stride_h, stride_w = 2, 2
            xs = []
            for sh in range(stride_h):
                for sw in range(stride_w):
                    xs.append(x[:, sh::stride_h, sw::stride_w])
            x = torch.cat(xs, dim=0)
            fea_list = []
            torch.cuda.empty_cache()
            for i in range(x.shape[0]):
                fea_list.append(self.attn(x[i:i + 1]))
                torch.cuda.empty_cache()
            x = torch.cat(fea_list, dim=0)
            i = 0
            for sh in range(stride_h):
                for sw in range(stride_w):
                    # Each chunk of B samples corresponds to one (sh, sw) sub-grid.
                    fea[:, sh::stride_h, sw::stride_w] = x[i * B:(i + 1) * B]
                    i += 1
            x = fea
        else:
            # Window partition
            if self.window_size > 0:
                H, W = x.shape[1], x.shape[2]
                x, pad_hw = window_partition(x, self.window_size)
            x = self.attn(x)
            # Reverse window partition
            if self.window_size > 0:
                x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
if self.use_residual_block:
x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
if self.use_convnext_block:
x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
return x
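# Usage sketch (assumed sizes): a windowed block and a global block both map a 64x64 token grid
# to the same shape; window_size=0 selects global attention.
#
#   win_blk = Block(dim=768, num_heads=12, window_size=14, input_size=(14, 14))
#   glb_blk = Block(dim=768, num_heads=12, window_size=0, input_size=(64, 64))
#   x = torch.randn(1, 64, 64, 768)
#   y = glb_blk(win_blk(x))                               # -> (1, 64, 64, 768)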
class ViT(Backbone):
"""
This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
"Exploring Plain Vision Transformer Backbones for Object Detection",
https://arxiv.org/abs/2203.16527
"""
def __init__(
self,
img_size=1024,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
use_abs_pos=True,
use_rel_pos=False,
rel_pos_zero_init=True,
window_size=0,
window_block_indexes=(),
residual_block_indexes=(),
use_act_checkpoint=False,
pretrain_img_size=224,
pretrain_use_cls_token=True,
out_feature="last_feat",
res_conv_kernel_size=3,
res_conv_padding=1,
):
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
window_block_indexes (list): Indexes for blocks using window attention.
residual_block_indexes (list): Indexes for blocks using conv propagation.
use_act_checkpoint (bool): If True, use activation checkpointing.
pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
out_feature (str): name of the feature from the last block.
"""
super().__init__()
self.pretrain_use_cls_token = pretrain_use_cls_token
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
else:
self.pos_embed = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i in window_block_indexes else 0,
use_residual_block=i in residual_block_indexes,
input_size=(img_size // patch_size, img_size // patch_size),
res_conv_kernel_size=res_conv_kernel_size,
res_conv_padding=res_conv_padding,
)
if use_act_checkpoint:
block = checkpoint_wrapper(block)
self.blocks.append(block)
self._out_feature_channels = {out_feature: embed_dim}
self._out_feature_strides = {out_feature: patch_size}
self._out_features = [out_feature]
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + get_abs_pos(
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
)
for blk in self.blocks:
x = blk(x)
        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
        return outputs[self._out_features[0]]
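# Construction sketch (hyperparameters are assumptions for illustration, not the project's
# configs): a ViT-Base backbone with 14x14 window attention everywhere except four global
# blocks, returning a stride-16 feature map.
#
#   backbone = ViT(
#       img_size=512, patch_size=16, embed_dim=768, depth=12, num_heads=12,
#       drop_path_rate=0.1, window_size=14, use_rel_pos=True,
#       window_block_indexes=[0, 1, 3, 4, 6, 7, 9, 10],   # blocks 2, 5, 8, 11 stay global
#       out_feature="last_feat",
#   )
#   feat = backbone(torch.randn(1, 3, 512, 512))          # -> (1, 768, 32, 32)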
from .matting_criterion import MattingCriterion
import torch
import torch.nn as nn
import torch.nn.functional as F
class MattingCriterion(nn.Module):
def __init__(self,
*,
losses,
):
super(MattingCriterion, self).__init__()
self.losses = losses
    def loss_gradient_penalty(self, sample_map, preds, targets):
        preds = preds['phas']
        targets = targets['phas']
        # sample_map marks the unknown region of the trimap; rescale by the inverse of its
        # relative area (262144 = 512 * 512, the nominal image area).
        scale = sample_map.shape[0] * 262144 / torch.sum(sample_map)
#gradient in x
sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type())
delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
#gradient in y
sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type())
delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)
#loss
loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \
F.l1_loss(delta_pred_y*sample_map, delta_gt_y*sample_map)* scale + \
0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \
0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale)
return dict(loss_gradient_penalty=loss)
def loss_pha_laplacian(self, preds, targets):
assert 'phas' in preds and 'phas' in targets
loss = laplacian_loss(preds['phas'], targets['phas'])
return dict(loss_pha_laplacian=loss)
def unknown_l1_loss(self, sample_map, preds, targets):
scale = sample_map.shape[0]*262144/torch.sum(sample_map)
# scale = 1
loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale
return dict(unknown_l1_loss=loss)
def known_l1_loss(self, sample_map, preds, targets):
new_sample_map = torch.zeros_like(sample_map)
new_sample_map[sample_map==0] = 1
if torch.sum(new_sample_map) == 0:
scale = 0
else:
scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
# scale = 1
loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
return dict(known_l1_loss=loss)
def forward(self, sample_map, preds, targets):
losses = dict()
for k in self.losses:
if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty':
losses.update(getattr(self, k)(sample_map, preds, targets))
else:
losses.update(getattr(self, k)(preds, targets))
return losses
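# Usage sketch (tensor shapes are assumptions): the loss names map directly to the methods
# defined above, and sample_map is a binary mask of the trimap's unknown region.
#
#   criterion = MattingCriterion(losses=['unknown_l1_loss', 'known_l1_loss',
#                                        'loss_pha_laplacian', 'loss_gradient_penalty'])
#   preds = {'phas': torch.rand(2, 1, 512, 512)}
#   targets = {'phas': torch.rand(2, 1, 512, 512)}
#   sample_map = (torch.rand(2, 1, 512, 512) > 0.5).float()
#   loss_dict = criterion(sample_map, preds, targets)     # four scalar losses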
#-----------------Laplacian Loss-------------------------#
def laplacian_loss(pred, true, max_levels=5):
kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
true_pyramid = laplacian_pyramid(true, kernel, max_levels)
loss = 0
for level in range(max_levels):
loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
return loss / max_levels
def laplacian_pyramid(img, kernel, max_levels):
current = img
pyramid = []
for _ in range(max_levels):
current = crop_to_even_size(current)
down = downsample(current, kernel)
up = upsample(down, kernel)
diff = current - up
pyramid.append(diff)
current = down
return pyramid
def gauss_kernel(device='cpu', dtype=torch.float32):
kernel = torch.tensor([[1, 4, 6, 4, 1],
[4, 16, 24, 16, 4],
[6, 24, 36, 24, 6],
[4, 16, 24, 16, 4],
[1, 4, 6, 4, 1]], device=device, dtype=dtype)
kernel /= 256
kernel = kernel[None, None, :, :]
return kernel
def gauss_convolution(img, kernel):
B, C, H, W = img.shape
img = img.reshape(B * C, 1, H, W)
img = F.pad(img, (2, 2, 2, 2), mode='reflect')
img = F.conv2d(img, kernel)
img = img.reshape(B, C, H, W)
return img
def downsample(img, kernel):
img = gauss_convolution(img, kernel)
img = img[:, :, ::2, ::2]
return img
def upsample(img, kernel):
B, C, H, W = img.shape
out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
out[:, :, ::2, ::2] = img * 4
out = gauss_convolution(out, kernel)
return out
def crop_to_even_size(img):
H, W = img.shape[2:]
H = H - H % 2
W = W - W % 2
return img[:, :, :H, :W]
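# Pyramid shape sketch (assumed 512x512 input, max_levels=5): laplacian_pyramid returns the
# Laplacian residuals at full, 1/2, 1/4, 1/8 and 1/16 resolution, and laplacian_loss weights
# the per-level L1 terms by 2**level before averaging over the levels.
#
#   pyr = laplacian_pyramid(torch.rand(1, 1, 512, 512), gauss_kernel(), max_levels=5)
#   [p.shape[-1] for p in pyr]                            # [512, 256, 128, 64, 32]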
from .detail_capture import Detail_Capture
import torch
from torch import nn
from torch.nn import functional as F
class Basic_Conv3x3(nn.Module):
"""
Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
"""
def __init__(
self,
in_chans,
out_chans,
stride=2,
padding=1,
):
super().__init__()
self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
self.bn = nn.BatchNorm2d(out_chans)
self.relu = nn.ReLU(True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class ConvStream(nn.Module):
"""
Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
"""
def __init__(
self,
in_chans = 4,
out_chans = [48, 96, 192],
):
super().__init__()
self.convs = nn.ModuleList()
self.conv_chans = out_chans.copy()
self.conv_chans.insert(0, in_chans)
for i in range(len(self.conv_chans)-1):
in_chan_ = self.conv_chans[i]
out_chan_ = self.conv_chans[i+1]
self.convs.append(
Basic_Conv3x3(in_chan_, out_chan_)
)
def forward(self, x):
out_dict = {'D0': x}
for i in range(len(self.convs)):
x = self.convs[i](x)
name_ = 'D'+str(i+1)
out_dict[name_] = x
return out_dict
class Fusion_Block(nn.Module):
"""
Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer.
"""
def __init__(
self,
in_chans,
out_chans,
):
super().__init__()
self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
def forward(self, x, D):
F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
out = torch.cat([D, F_up], dim=1)
out = self.conv(out)
return out
class Matting_Head(nn.Module):
"""
Simple Matting Head, containing only conv3x3 and conv1x1 layers.
"""
def __init__(
self,
in_chans = 32,
mid_chans = 16,
):
super().__init__()
self.matting_convs = nn.Sequential(
nn.Conv2d(in_chans, mid_chans, 3, 1, 1),
nn.BatchNorm2d(mid_chans),
nn.ReLU(True),
nn.Conv2d(mid_chans, 1, 1, 1, 0)
)
def forward(self, x):
x = self.matting_convs(x)
return x
class Detail_Capture(nn.Module):
"""
Simple and Lightweight Detail Capture Module for ViT Matting.
"""
def __init__(
self,
in_chans = 384,
img_chans=4,
convstream_out = [48, 96, 192],
fusion_out = [256, 128, 64, 32],
):
super().__init__()
assert len(fusion_out) == len(convstream_out) + 1
self.convstream = ConvStream(in_chans = img_chans)
self.conv_chans = self.convstream.conv_chans
self.fusion_blks = nn.ModuleList()
self.fus_channs = fusion_out.copy()
self.fus_channs.insert(0, in_chans)
for i in range(len(self.fus_channs)-1):
self.fusion_blks.append(
Fusion_Block(
in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
out_chans = self.fus_channs[i+1],
)
)
self.matting_head = Matting_Head(
in_chans = fusion_out[-1],
)
def forward(self, features, images):
detail_features = self.convstream(images)
for i in range(len(self.fusion_blks)):
d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
features = self.fusion_blks[i](features, detail_features[d_name_])
phas = torch.sigmoid(self.matting_head(features))
return {'phas': phas}
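# Usage sketch (assumed shapes, with the default in_chans=384 matching a ViT-S backbone): the
# decoder upsamples the stride-16 ViT features x2 four times, fusing D3..D0 detail features on
# the way, so a 512x512 input yields a full-resolution alpha prediction.
#
#   decoder = Detail_Capture(in_chans=384, img_chans=4)
#   features = torch.randn(1, 384, 32, 32)                # ViT tokens for a 512x512, patch-16 input
#   images = torch.randn(1, 4, 512, 512)                  # RGB + trimap, channels-first
#   out = decoder(features, images)                       # {'phas': tensor of shape (1, 1, 512, 512)}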
from .vitmatte import ViTMatte
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
from detectron2.structures import ImageList
class ViTMatte(nn.Module):
def __init__(self,
*,
backbone,
criterion,
pixel_mean,
pixel_std,
input_format,
size_divisibility,
decoder,
):
super(ViTMatte, self).__init__()
self.backbone = backbone
self.criterion = criterion
self.input_format = input_format
self.size_divisibility = size_divisibility
self.decoder = decoder
self.register_buffer(
"pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False
)
self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
assert (
self.pixel_mean.shape == self.pixel_std.shape
), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
images, targets, H, W = self.preprocess_inputs(batched_inputs)
features = self.backbone(images)
outputs = self.decoder(features, images)
if self.training:
assert targets is not None
trimap = images[:, 3:4]
sample_map = torch.zeros_like(trimap)
sample_map[trimap==0.5] = 1
            losses = self.criterion(sample_map, outputs, targets)
return losses
else:
outputs['phas'] = outputs['phas'][:,:,:H,:W]
return outputs
def preprocess_inputs(self, batched_inputs):
"""
Normalize, pad and batch the input images.
"""
images = batched_inputs["image"].to(self.device)
trimap = batched_inputs['trimap'].to(self.device)
images = (images - self.pixel_mean) / self.pixel_std
if 'fg' in batched_inputs.keys():
trimap[trimap < 85] = 0
trimap[trimap >= 170] = 1
trimap[trimap >= 85] = 0.5
images = torch.cat((images, trimap), dim=1)
B, C, H, W = images.shape
if images.shape[-1]%32!=0 or images.shape[-2]%32!=0:
new_H = (32-images.shape[-2]%32) + H
new_W = (32-images.shape[-1]%32) + W
new_images = torch.zeros((images.shape[0], images.shape[1], new_H, new_W)).to(self.device)
new_images[:,:,:H,:W] = images[:,:,:,:]
del images
images = new_images
if "alpha" in batched_inputs:
phas = batched_inputs["alpha"].to(self.device)
else:
phas = None
return images, dict(phas=phas), H, W
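# Inference sketch (value ranges are assumptions read off preprocess_inputs, and `model` is
# assumed to be a constructed ViTMatte whose backbone accepts 4-channel input): at eval time
# the trimap is expected to already be in {0, 0.5, 1} (the 'fg' branch above only rescales raw
# 0-255 trimaps during training), and the predicted alpha is cropped back to the input size.
#
#   batched_inputs = {
#       'image': torch.rand(1, 3, 512, 512),              # RGB, same scale as pixel_mean/pixel_std
#       'trimap': torch.full((1, 1, 512, 512), 0.5),      # all-unknown trimap, values in {0, 0.5, 1}
#   }
#   model.eval()
#   with torch.no_grad():
#       alpha = model(batched_inputs)['phas']             # -> (1, 1, 512, 512)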
#torch==2.0.0
#torchvision
#tensorboard
timm==0.5.4
opencv-python==4.5.3.56
setuptools==58.2.0
easydict
wget
scikit-image
gradio==3.34.0
fairscale