"lib/llm/vscode:/vscode.git/clone" did not exist on "c4106e6a27290d7ec36f7661a0edf7f214495ea4"
Commit f55a786e authored by luopl's avatar luopl
Browse files

Initial commit

parents
Pipeline #1081 canceled with stages
# Copyright (c) Facebook, Inc. and its affiliates.
from .backbone.swin import D2SwinTransformer
from .backbone.clip import CLIP
from .heads.sed_head import SEDHead
# Copyright (c) Facebook, Inc. and its affiliates.
"""
Copyright (2023) Bytedance Ltd. and/or its affiliates
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn.functional as F
import math
from detectron2.utils import comm
import open_clip
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
@BACKBONE_REGISTRY.register()
class CLIP(Backbone):
def __init__(self, cfg, input_shape):
super().__init__()
model_name = cfg.MODEL.ENC.CLIP_MODEL_NAME
        pretrained = cfg.MODEL.ENC.CLIP_PRETRAINED_WEIGHTS
# download on local rank 0 first
if comm.get_local_rank() == 0:
open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
comm.synchronize()
self.clip_model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
self.text_tokenizer = open_clip.get_tokenizer(model_name)
model_name = model_name.lower()
if 'convnext_' in model_name:
self.model_type = 'convnext'
if '_base' in model_name:
self.output_channels = [128, 128, 256, 512, 1024]
elif '_large' in model_name:
self.output_channels = [192, 192, 384, 768, 1536]
elif '_xxlarge' in model_name:
self.output_channels = [384, 384, 768, 1536, 3072]
self._out_feature_strides = {
"stem": 2,
"res2": 4,
"res3": 8,
"res4": 16,
"res5": 32,
"clip_embedding": -1
}
self._out_feature_channels = {
"stem": self.output_channels[0],
"res2": self.output_channels[1],
"res3": self.output_channels[2],
"res4": self.output_channels[3],
"res5": self.output_channels[4],
"clip_embedding": self.dim_latent
}
self.eval()
self.freeze_everything()
def freeze_everything(self):
for param in self.clip_model.parameters():
param.requires_grad = False
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.clip_model.transformer.get_cast_dtype()
x = self.clip_model.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.clip_model.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.clip_model.transformer(x, attn_mask=self.clip_model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.clip_model.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.clip_model.text_projection
return F.normalize(x, dim=-1) if normalize else x
def tokenize_text(self, text):
return self.text_tokenizer(text)
def extract_features(self, x):
return {
'convnext': self.extract_features_convnext,
}[self.model_type](x)
def visual_prediction_forward(self, x):
return {
'convnext': self.visual_prediction_forward_convnext,
}[self.model_type](x)
def extract_features_convnext(self, x):
out = {}
x = self.clip_model.visual.trunk.stem(x)
out['stem'] = x.contiguous() # os4
for i in range(4):
x = self.clip_model.visual.trunk.stages[i](x)
out[f'res{i+2}'] = x.contiguous() # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
x = self.clip_model.visual.trunk.norm_pre(x)
out['clip_vis_dense'] = x.contiguous()
return out
    def visual_prediction_forward_convnext(self, x):
batch, num_query, channel = x.shape
x = x.reshape(batch*num_query, channel, 1, 1) # fake 2D input
x = self.clip_model.visual.trunk.head(x)
x = self.clip_model.visual.head(x)
return x.view(batch, num_query, x.shape[-1]) # B x num_queries x 640
def get_text_classifier(self, text_list, device):
self.eval()
with torch.no_grad():
# reference for templates: https://github.com/mlfoundations/open_clip/blob/91f6cce16b7bee90b3b5d38ca305b5b3b67cc200/src/training/imagenet_zeroshot_data.py
text_tokens = self.tokenize_text(text_list)
text_tokens = text_tokens.to(device)
# we return un-normalized text feature.
text_features = self.encode_text(text_tokens, normalize=False)
return text_features
def forward(self, x):
self.eval()
with torch.no_grad():
return self.extract_features(x)
@property
def dim_latent(self):
return self.clip_model.text_projection.shape[-1]
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in ["stem", "res2", "res3", "res4", "res5", "clip_embedding"]
}
@property
def size_divisibility(self):
return -1
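# Illustrative usage sketch (not part of the original file; the cfg keys are assumed
# to come from this project's detectron2 config):
#   backbone = CLIP(cfg, input_shape=None)   # cfg.MODEL.ENC.CLIP_MODEL_NAME, e.g. an
#                                            # open_clip ConvNeXt such as "convnext_large_d_320"
#   feats = backbone(images)                 # dict with "stem", "res2".."res5", "clip_vis_dense"
#   txt = backbone.get_text_classifier(["a photo of a cat"], images.device)
#   # txt: (num_prompts, dim_latent) un-normalized text features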
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
class Mlp(nn.Module):
"""Multilayer perceptron."""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
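# Shape sketch (illustrative values): with x of shape (2, 56, 56, 96) and window_size=7,
# window_partition(x, 7) yields (2*8*8, 7, 7, 96) = (128, 7, 7, 96), and
# window_reverse(windows, 7, 56, 56) restores the original (2, 56, 56, 96).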
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
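# Illustrative sketch (assumed values): a 7x7 window attends over N = 49 tokens, and the
# relative position bias table holds (2*7-1)*(2*7-1) = 169 entries per head, gathered via
# relative_position_index:
#   attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#   out = attn(torch.randn(8, 49, 96))   # (num_windows*B, N, C) in, same shape out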
class SwinTransformerBlock(nn.Module):
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
"""Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
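# Illustrative shape sketch: with H = W = 56 and x of shape (B, 56*56, C), PatchMerging
# concatenates each 2x2 neighborhood (4*C channels), normalizes, and reduces to 2*C,
# returning (B, 28*28, 2*C); odd H or W is padded by one row/column first.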
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
"""Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
"""Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
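# Illustrative shape sketch: with patch_size=4 and embed_dim=96, a (B, 3, 224, 224) image
# is projected to a (B, 96, 56, 56) patch-token map; inputs whose H or W is not divisible
# by the patch size are right/bottom padded before the strided convolution.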
class SwinTransformer(nn.Module):
"""Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2), #3),
frozen_stages=-1,
use_checkpoint=False,
):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [
pretrain_img_size[0] // patch_size[0],
pretrain_img_size[1] // patch_size[1],
]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
)
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f"norm{i_layer}"
self.add_module(layer_name, layer)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.ape:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.pos_drop.eval()
for i in range(0, self.frozen_stages - 1):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
)
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = {}
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs["res{}".format(i + 2)] = out
return outs
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
@BACKBONE_REGISTRY.register()
class D2SwinTransformer(SwinTransformer, Backbone):
def __init__(self, cfg, input_shape):
pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
patch_size = cfg.MODEL.SWIN.PATCH_SIZE
in_chans = 3
embed_dim = cfg.MODEL.SWIN.EMBED_DIM
depths = cfg.MODEL.SWIN.DEPTHS
num_heads = cfg.MODEL.SWIN.NUM_HEADS
window_size = cfg.MODEL.SWIN.WINDOW_SIZE
mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
qk_scale = cfg.MODEL.SWIN.QK_SCALE
drop_rate = cfg.MODEL.SWIN.DROP_RATE
attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
norm_layer = nn.LayerNorm
ape = cfg.MODEL.SWIN.APE
patch_norm = cfg.MODEL.SWIN.PATCH_NORM
super().__init__(
pretrain_img_size,
patch_size,
in_chans,
embed_dim,
depths,
num_heads,
window_size,
mlp_ratio,
qkv_bias,
qk_scale,
drop_rate,
attn_drop_rate,
drop_path_rate,
norm_layer,
ape,
patch_norm,
)
self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
self._out_feature_strides = {
"res2": 4,
"res3": 8,
"res4": 16,
#"res5": 32,
}
self._out_feature_channels = {
"res2": self.num_features[0],
"res3": self.num_features[1],
"res4": self.num_features[2],
#"res5": self.num_features[3],
}
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
assert (
x.dim() == 4
), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
outputs = {}
y = super().forward(x)
for k in y.keys():
if k in self._out_features:
outputs[k] = y[k]
return outputs
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
@property
def size_divisibility(self):
return 32
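# Illustrative usage sketch (assumes a detectron2 config that provides the MODEL.SWIN keys
# read above and sets MODEL.BACKBONE.NAME to "D2SwinTransformer"):
#   from detectron2.modeling import build_backbone
#   backbone = build_backbone(cfg)
#   feats = backbone(images)   # dict of the features named in MODEL.SWIN.OUT_FEATURES,
#                              # e.g. "res2" (stride 4), "res3" (stride 8), "res4" (stride 16)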
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
from einops import rearrange
import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from ..transformer.sed_predictor import SEDPredictor
@SEM_SEG_HEADS_REGISTRY.register()
class SEDHead(nn.Module):
@configurable
def __init__(
self,
input_shape: Dict[str, ShapeSpec],
*,
num_classes: int,
ignore_value: int = -1,
# extra parameters
feature_resolution: list,
transformer_predictor: nn.Module,
):
"""
NOTE: this interface is experimental.
Args:
input_shape: shapes (channels and stride) of the input features
num_classes: number of classes to predict
            ignore_value: category id to be ignored during training.
            feature_resolution: spatial resolution of the features fed to the predictor
            transformer_predictor: the transformer decoder that makes prediction
"""
super().__init__()
input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
self.in_features = [k for k, v in input_shape]
self.ignore_value = ignore_value
self.predictor = transformer_predictor
self.num_classes = num_classes
self.feature_resolution = feature_resolution
@classmethod
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
return {
"input_shape": {
k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
},
"ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
"num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
"feature_resolution": cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION,
"transformer_predictor": SEDPredictor(
cfg,
),
}
def forward(self, features, guidance_features):
"""
Arguments:
            features: image features forwarded to the transformer predictor
            guidance_features: guidance features forwarded to the transformer predictor
"""
img_feat = features
return self.predictor(img_feat, guidance_features)
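# Illustrative usage sketch (assumed cfg keys as read in from_config above); the head is
# normally built through detectron2's SEM_SEG_HEADS_REGISTRY:
#   from detectron2.modeling import build_sem_seg_head
#   head = build_sem_seg_head(cfg, backbone.output_shape())
#   logits = head(features, guidance_features)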
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
class ConvNextV2Block(nn.Module):
""" ConvNeXtV2 Block.
Args:
dim (int): Number of input channels.
        kernel_size (int): Depthwise convolution kernel size. Default: 7
        drop_path (float): Stochastic depth rate. Default: 0.0
    """
    def __init__(self, dim, kernel_size=7, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(4 * dim)
self.pwconv2 = nn.Linear(4 * dim, dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
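# Sketch of the GRN computation for a channels_last input x of shape (N, H, W, C):
#   Gx = ||x||_2 over the spatial dims (per sample, per channel)  -> (N, 1, 1, C)
#   Nx = Gx / (mean over channels of Gx + 1e-6)                   -> (N, 1, 1, C)
#   out = gamma * (x * Nx) + beta + x   (the identity path is kept inside the layer)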
class ConvNextBlock(nn.Module):
r""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(self, dim, kernel_size=7, drop_path=0., layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
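# Illustrative sketch: LayerNorm(64, data_format="channels_first") normalizes a
# (N, 64, H, W) tensor over its channel dimension via the manual mean/variance path above,
# while the default "channels_last" simply wraps F.layer_norm for (N, H, W, 64) inputs.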
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from .convnext import ConvNextBlock, ConvNextV2Block
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
def window_partition(x, window_size: int):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
head_dim (int): Number of channels per head (dim // num_heads if not set)
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, appearance_guidance_dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = to_2tuple(window_size) # Wh, Ww
win_h, win_w = self.window_size
self.window_area = win_h * win_w
self.num_heads = num_heads
head_dim = head_dim or dim // num_heads
attn_dim = head_dim * num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
self.k = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
self.v = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
q = self.q(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
k = self.k(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
v = self.v(x[:, :, :self.dim]).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if mask is not None:
num_win = mask.shape[0]
attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
window_size (int): Window size.
num_heads (int): Number of attention heads.
head_dim (int): Enforce the number of channels per head
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self, dim, appearance_guidance_dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, appearance_guidance_dim=appearance_guidance_dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)):
for w in (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x, appearance_guidance):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
if appearance_guidance is not None:
appearance_guidance = appearance_guidance.view(B, H, W, -1)
x = torch.cat([x, appearance_guidance], dim=-1)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1]) # num_win*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class SwinTransformerBlockWrapper(nn.Module):
def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
super().__init__()
self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=0)
self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=window_size // 2)
self.guidance_norm = nn.LayerNorm(appearance_guidance_dim) if appearance_guidance_dim > 0 else None
def forward(self, x, appearance_guidance=None):
"""
Arguments:
x: B C T H W
appearance_guidance: B C H W
"""
BT, C, H, W = x.shape
x = rearrange(x, 'BT C H W -> BT (H W) C')
if appearance_guidance is not None:
# appearance_guidance = self.guidance_norm(repeat(appearance_guidance, 'B C H W -> (B T) (H W) C', T=T))
pass
x = self.block_1(x, appearance_guidance)
x = self.block_2(x, appearance_guidance)
x = rearrange(x, 'BT (H W) C -> BT C H W', H=H, W=W)
return x
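# Sketch: the wrapper flattens each (B*T, C, H, W) cost-volume slice into H*W tokens and
# runs one non-shifted block (shift_size=0) followed by one shifted block
# (shift_size=window_size // 2), i.e. a W-MSA / SW-MSA pair per call.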
def elu_feature_map(x):
return torch.nn.functional.elu(x) + 1
class LinearAttention(nn.Module):
def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map
self.eps = eps
def forward(self, queries, keys, values):
""" Multi-Head linear attention proposed in "Transformers are RNNs"
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
Returns:
queried_values: (N, L, H, D)
"""
Q = self.feature_map(queries)
K = self.feature_map(keys)
v_length = values.size(1)
values = values / v_length # prevent fp16 overflow
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
return queried_values.contiguous()
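# Illustrative sketch (assumed shapes): queries (N, L, H, D), keys/values (N, S, H, D).
# Because K^T V is aggregated once over S, the cost scales with L*D*D per head instead of
# L*S as in full attention:
#   attn = LinearAttention()
#   out = attn(torch.randn(2, 100, 8, 32),
#              torch.randn(2, 50, 8, 32),
#              torch.randn(2, 50, 8, 32))   # -> (2, 100, 8, 32)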
class FullAttention(nn.Module):
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
""" Multi-head scaled dot-product attention, a.k.a full attention.
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
if kv_mask is not None:
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
# Compute the attention and the weighted average
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
A = torch.softmax(softmax_temp * QK, dim=2)
if self.use_dropout:
A = self.dropout(A)
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
return queried_values.contiguous()
class AttentionLayer(nn.Module):
def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
super().__init__()
self.nheads = nheads
self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
self.v = nn.Linear(hidden_dim, hidden_dim)
if attention_type == 'linear':
self.attention = LinearAttention()
elif attention_type == 'full':
self.attention = FullAttention()
else:
raise NotImplementedError
def forward(self, x, guidance):
"""
Arguments:
x: B, L, C
guidance: B, L, C
"""
q = self.q(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.q(x)
k = self.k(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.k(x)
v = self.v(x)
q = rearrange(q, 'B L (H D) -> B L H D', H=self.nheads)
k = rearrange(k, 'B S (H D) -> B S H D', H=self.nheads)
v = rearrange(v, 'B S (H D) -> B S H D', H=self.nheads)
out = self.attention(q, k, v)
out = rearrange(out, 'B L H D -> B L (H D)')
return out
class ClassTransformerLayer(nn.Module):
def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
super().__init__()
self.pool = nn.AvgPool2d(pooling_size)
self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
self.MLP = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.ReLU(),
nn.Linear(hidden_dim * 4, hidden_dim)
)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
def pool_features(self, x):
"""
Intermediate pooling layer for computational efficiency.
Arguments:
x: B, C, T, H, W
"""
B = x.size(0)
x = rearrange(x, 'B C T H W -> (B T) C H W')
x = self.pool(x)
x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
return x
def forward(self, x, guidance):
"""
Arguments:
x: B, C, T, H, W
guidance: B, T, C
"""
B, _, _, H, W = x.size()
x_pool = self.pool_features(x)
*_, H_pool, W_pool = x_pool.size()
x_pool = rearrange(x_pool, 'B C T H W -> (B H W) T C')
if guidance is not None:
guidance = repeat(guidance, 'B T C -> (B H W) T C', H=H_pool, W=W_pool)
x_pool = x_pool + self.attention(self.norm1(x_pool), guidance) # Attention
x_pool = x_pool + self.MLP(self.norm2(x_pool)) # MLP
x_pool = rearrange(x_pool, '(B H W) T C -> (B T) C H W', H=H_pool, W=W_pool)
x_pool = F.interpolate(x_pool, size=(H, W), mode='bilinear', align_corners=True)
x_pool = rearrange(x_pool, '(B T) C H W -> B C T H W', B=B)
x = x + x_pool # Residual
return x
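# Sketch: ClassTransformerLayer.forward average-pools x (B, C, T, H, W) spatially, reshapes
# it so every pooled location attends over the T class slots (optionally conditioned on the
# text guidance), then bilinearly upsamples back to (H, W) and adds the result to x.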
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AggregatorLayer(nn.Module):
def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4, input_resolution=(20, 20), pooling_size=(5, 5), window_size=(10, 10), attention_type='linear') -> None:
super().__init__()
self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)
def forward(self, x, appearance_guidance, text_guidance):
"""
Arguments:
x: B C T H W
"""
x = self.swin_block(x, appearance_guidance)
x = self.attention(x, text_guidance)
return x
class AggregatorResNetLayer(nn.Module):
def __init__(self, hidden_dim=64, appearance_guidance=512) -> None:
super().__init__()
self.conv_linear = nn.Conv2d(hidden_dim + appearance_guidance, hidden_dim, kernel_size=1, stride=1)
self.conv_layer = Bottleneck(hidden_dim, hidden_dim // 4)
def forward(self, x, appearance_guidance):
"""
Arguments:
x: B C T H W
"""
B, T = x.size(0), x.size(2)
x = rearrange(x, 'B C T H W -> (B T) C H W')
appearance_guidance = repeat(appearance_guidance, 'B C H W -> (B T) C H W', T=T)
x = self.conv_linear(torch.cat([x, appearance_guidance], dim=1))
x = self.conv_layer(x)
x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
return x
class DoubleConv(nn.Module):
"""(convolution => [GN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(mid_channels // 16, mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(out_channels // 16, out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, guidance_channels, text_guidance_channels, nheads, attention_type, pooling_size, cnext_type="V1", kernel_size=7,
input_resolution=(24,24), window_size=(10, 10)):
super().__init__()
corr_guidance_channels = guidance_channels
self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels+corr_guidance_channels, out_channels)
if cnext_type=="V1":
self.cnext = ConvNextBlock(in_channels, kernel_size, 0.0, 1.0)
elif cnext_type=="V2":
self.cnext = ConvNextV2Block(in_channels, kernel_size, 0.0)
elif cnext_type == "Swin":
self.cnext = SwinTransformerBlockWrapper(in_channels, 0, input_resolution, nheads, window_size)
else:
raise NotImplementedError
self.attention = ClassTransformerLayer(in_channels, text_guidance_channels, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)
def forward(self, x, text_guidance, guidance=None, corr_guidance=None):
B, _, _, _ = guidance.size()
x = self.cnext(x)
x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
x = self.attention(x, text_guidance)
x = rearrange(x, 'B C T H W -> (B T) C H W')
x = self.up(x)
if guidance is not None:
T = x.size(0) // guidance.size(0)
guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
x = torch.cat([x, guidance], dim=1)
if corr_guidance is not None:
x = torch.cat([x, corr_guidance], dim=1)
return self.conv(x)
class Aggregator(nn.Module):
def __init__(self,
text_guidance_dim=512,
text_guidance_proj_dim=128,
appearance_guidance_dim=512,
appearance_guidance_proj_dim=128,
decoder_dims = (64, 32),
decoder_guidance_dims=(256, 128),
decoder_corr_guidance_dims=(80, 32, 16),
decoder_guidance_proj_dims=(32, 16),
num_layers=4,
nheads=4,
hidden_dim=128,
pooling_size=(6, 6),
feature_resolution=(24, 24),
window_size=12,
attention_type='linear',
prompt_channel=80,
cnext_type="V1",
kernel_size=(7, 7, 7),
fast_inference=False,
topK=1,
) -> None:
super().__init__()
self.num_layers = num_layers
self.hidden_dim = hidden_dim
# self.layers = nn.ModuleList([
# AggregatorLayer(
# hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim,
# nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type
# ) for _ in range(num_layers)
# ])
self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3)
# self.guidance_projection = nn.Sequential(
# nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
# nn.ReLU(),
# ) if appearance_guidance_dim > 0 else None
self.text_guidance_projection = nn.ModuleList([
nn.Sequential(
nn.Linear(text_guidance_dim, tp),
nn.ReLU(),
) for tp in [hidden_dim, decoder_dims[0], decoder_dims[1]]
]) if text_guidance_dim > 0 else None
if text_guidance_dim > 0:
text_guidance_channels = [hidden_dim, decoder_dims[0], decoder_dims[1]]
else:
text_guidance_channels = [0, 0, 0]
self.decoder_guidance_projection = nn.ModuleList([
nn.Sequential(
nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
]) if decoder_guidance_dims[0] > 0 else None
self.decoder_corr_guidance_projection = nn.ModuleList([
nn.Sequential(
nn.ConvTranspose2d(d, dp, kernel_size=2, stride=2),
nn.ReLU(),
) for d, dp in zip(decoder_corr_guidance_dims, decoder_guidance_proj_dims)
]) if decoder_corr_guidance_dims[0] > 0 else None
self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0], text_guidance_channels[0], nheads, attention_type, pooling_size, cnext_type, kernel_size[0], (24, 24), window_size)
self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1], text_guidance_channels[1], nheads, attention_type, pooling_size, cnext_type, kernel_size[1], (48, 48), window_size)
self.decoder3 = Up(decoder_dims[1], decoder_dims[2], decoder_guidance_proj_dims[2], text_guidance_channels[2], nheads, attention_type, pooling_size, cnext_type, kernel_size[2], (96, 96), window_size)
self.head0 = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1)
self.head1 = nn.Conv2d(decoder_dims[0], 1, kernel_size=3, stride=1, padding=1)
self.head2 = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
self.head = nn.Conv2d(decoder_dims[2], 1, kernel_size=3, stride=1, padding=1)
self.fast_infer = fast_inference
self.topK = topK
# self.clip_vis_projection = nn.Conv2d(1536, 768, kernel_size=3, stride=1, padding=1)
def feature_map(self, img_feats, text_feats):
img_feats = F.normalize(img_feats, dim=1) # B C H W
img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1])
text_feats = F.normalize(text_feats, dim=-1) # B T P C
text_feats = text_feats.mean(dim=-2)
text_feats = F.normalize(text_feats, dim=-1) # B T C
text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1])
return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W
def correlation(self, img_feats, text_feats, logit_scale):
img_feats = F.normalize(img_feats, dim=1) # B C H W
text_feats = F.normalize(text_feats, dim=-1) # B T P C
corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
logit_scale = torch.clamp(logit_scale.exp(), max=100)
# corr = logit_scale * corr
return corr
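    # Shape sketch for correlation(): img_feats (B, C, H, W) and text_feats (B, T, P, C)
    # are L2-normalized and contracted over C, yielding a cosine-similarity cost volume of
    # shape (B, P, T, H, W) (P prompt templates per class, T class names); the clamped
    # logit_scale is computed but, as written above, not applied.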
def corr_embed(self, x):
B = x.shape[0]
corr_embed = rearrange(x, 'B P T H W -> (B T) P H W')
corr_embed = self.conv1(corr_embed)
corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
return corr_embed
def corr_projection(self, x, proj):
corr_embed = rearrange(x, 'B C T H W -> B T H W C')
corr_embed = proj(corr_embed)
corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W')
return corr_embed
def upsample(self, x):
B = x.shape[0]
corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True)
corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
return corr_embed
def conv_decoder(self, x, text_guidance, guidance, corr_guidance):
B = x.shape[0]
corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
mask_aux = self.head0(corr_embed.detach())
mask_aux = rearrange(mask_aux, '(B T) () H W -> B T H W', B=B)
corr_embed = self.decoder1(corr_embed, text_guidance[0], guidance[0], corr_guidance[0])
mask_aux0 = self.head1(corr_embed.detach())
mask_aux0 = rearrange(mask_aux0, '(B T) () H W -> B T H W', B=B)
corr_embed = self.decoder2(corr_embed, text_guidance[1], guidance[1], corr_guidance[1])
mask_aux1 = self.head2(corr_embed.detach())
mask_aux1 = rearrange(mask_aux1, '(B T) () H W -> B T H W', B=B)
corr_embed = self.decoder3(corr_embed, text_guidance[2], guidance[2], corr_guidance[2])
corr_embed = self.head(corr_embed)
corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
return corr_embed, mask_aux, mask_aux0, mask_aux1
def fast_conv_decoder(self, x, text_guidance, guidance, corr_guidance, topK):
B = x.shape[0]
T = x.shape[2]
corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
mask_aux = self.head0(corr_embed)
mask_aux = rearrange(mask_aux, '(B T) () H W -> B T H W', B=B)
aux_valid = self.get_valid_idx_from_mask(mask_aux, topK)
corr_embed = self.decoder1(corr_embed[aux_valid], text_guidance[0], guidance[0], corr_guidance[0][aux_valid])
mask_aux0 = self.head1(corr_embed)
mask_aux0 = rearrange(mask_aux0, '(B T) () H W -> B T H W', B=B)
aux0_valid = self.get_valid_idx_from_mask(mask_aux0, topK)
corr_embed = self.decoder2(corr_embed[aux0_valid], text_guidance[1], guidance[1], corr_guidance[1][aux_valid][aux0_valid])
mask_aux1 = self.head2(corr_embed)
mask_aux1 = rearrange(mask_aux1, '(B T) () H W -> B T H W', B=B)
aux1_valid = self.get_valid_idx_from_mask(mask_aux1, topK)
corr_embed = self.decoder3(corr_embed[aux1_valid], text_guidance[2], guidance[2], corr_guidance[2][aux_valid][aux0_valid][aux1_valid])
corr_embed = self.head(corr_embed)
corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
valid_idx = aux_valid.clone()
valid_idx1 = aux0_valid.clone()
valid_idx1[aux0_valid] = aux1_valid
valid_idx[aux_valid] = valid_idx1
_, _, H, W = corr_embed.size()
output = torch.zeros([B, T, H, W]).cuda()
output[:, valid_idx] = corr_embed.sigmoid()
return [output]
def get_valid_idx_from_mask(self, mask, topK=2):
T = mask.size(1)
mask = torch.argsort(mask, dim=1, descending=True)
mask = mask[:, :topK].reshape(-1)
valid_idx = torch.bincount(mask, minlength=T) > 0
return valid_idx
def forward(self, img_feats, text_feats, appearance_guidance, logit_scale):
"""
Arguments:
img_feats: (B, C, H, W)
text_feats: (B, T, P, C)
            appearance_guidance: tuple of (B, C, H, W)
"""
# img_feats = self.clip_vis_projection(img_feats)
corr = self.correlation(img_feats, text_feats, logit_scale)
#corr = self.feature_map(img_feats, text_feats)
corr_embed = self.corr_embed(corr)
projected_guidance, projected_text_guidance, projected_decoder_guidance = None, [None, None, None], [None, None]
projected_corr_decoder_guidance = [None, None, None]
# if self.guidance_projection is not None:
# projected_guidance = self.guidance_projection(appearance_guidance[0])
if self.decoder_guidance_projection is not None:
projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])]
if self.decoder_corr_guidance_projection is not None:
corr = rearrange(corr, 'B P T H W -> (B T) P H W')
for i, proj in enumerate(self.decoder_corr_guidance_projection):
corr = proj(corr)
projected_corr_decoder_guidance[i] = corr
if self.text_guidance_projection is not None:
text_feats = text_feats.mean(dim=-2)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
projected_text_guidance = [proj(text_feats) for proj in self.text_guidance_projection]
if self.fast_infer and not self.training:
logit = self.fast_conv_decoder(corr_embed, projected_text_guidance, projected_decoder_guidance, projected_corr_decoder_guidance, topK=self.topK)
else:
logit = self.conv_decoder(corr_embed, projected_text_guidance, projected_decoder_guidance, projected_corr_decoder_guidance)
return logit
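# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how Aggregator.correlation
# builds its pixel-text cost volume. Shapes follow the docstring above
# (img_feats: B C H W, text_feats: B T P C); all sizes below are assumed
# placeholder values and the tensors are random stand-ins.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 16, 24, 24   # batch, channels, spatial size (assumed)
    T, P = 5, 3                  # classes, prompts per class (assumed)
    img_feats = F.normalize(torch.randn(B, C, H, W), dim=1)
    text_feats = F.normalize(torch.randn(B, T, P, C), dim=-1)

    # cosine similarity between every pixel and every (class, prompt) embedding
    corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
    assert corr.shape == (B, P, T, H, W)   # cost volume consumed by corr_embed()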
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d
from .model import Aggregator
from sed.third_party import clip
from sed.third_party import imagenet_templates
import numpy as np
import open_clip
class SEDPredictor(nn.Module):
@configurable
def __init__(
self,
*,
train_class_json: str,
test_class_json: str,
clip_pretrained: str,
prompt_ensemble_type: str,
text_guidance_dim: int,
text_guidance_proj_dim: int,
appearance_guidance_dim: int,
appearance_guidance_proj_dim: int,
prompt_depth: int,
prompt_length: int,
decoder_dims: list,
decoder_guidance_dims: list,
decoder_guidance_proj_dims: list,
num_heads: int,
num_layers: tuple,
hidden_dims: tuple,
pooling_sizes: tuple,
feature_resolution: tuple,
window_sizes: tuple,
attention_type: str,
cnext_type: str,
kernel_size: list,
fast_inference: bool,
topK: int,
):
"""
Args:
"""
super().__init__()
import json
# use class_texts in train_forward, and test_class_texts in test_forward
with open(train_class_json, 'r') as f_in:
self.class_texts = json.load(f_in)
with open(test_class_json, 'r') as f_in:
self.test_class_texts = json.load(f_in)
        assert self.class_texts is not None
        if self.test_class_texts is None:
            self.test_class_texts = self.class_texts
device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer = None
if clip_pretrained == "Convnext-L":
name, pretrain = ('convnext_large_d_320', 'laion2b_s29b_b131k_ft_soup')
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
name,
pretrained=pretrain,
device=device,)
self.tokenizer = open_clip.get_tokenizer(name)
elif clip_pretrained == "Convnext-B":
name, pretrain = ('convnext_base_w_320', 'laion_aesthetic_s13b_b82k')
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
name,
pretrained=pretrain,
device=device,)
self.tokenizer = open_clip.get_tokenizer(name)
elif clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
# for OpenCLIP models
name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
name,
pretrained=pretrain,
device=device,
force_image_size=336,)
self.tokenizer = open_clip.get_tokenizer(name)
else:
# for OpenAI models
clip_model, clip_preprocess = clip.load(clip_pretrained, device=device, jit=False, prompt_depth=prompt_depth, prompt_length=prompt_length)
self.prompt_ensemble_type = prompt_ensemble_type
if self.prompt_ensemble_type == "imagenet_select":
prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
elif self.prompt_ensemble_type == "imagenet":
prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
elif self.prompt_ensemble_type == "single":
prompt_templates = ['A photo of a {} in the scene',]
else:
raise NotImplementedError
self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
self.clip_model = clip_model.float()
self.clip_preprocess = clip_preprocess
transformer = Aggregator(
text_guidance_dim=text_guidance_dim,
text_guidance_proj_dim=text_guidance_proj_dim,
appearance_guidance_dim=appearance_guidance_dim,
appearance_guidance_proj_dim=appearance_guidance_proj_dim,
decoder_dims=decoder_dims,
decoder_guidance_dims=decoder_guidance_dims,
decoder_guidance_proj_dims=decoder_guidance_proj_dims,
num_layers=num_layers,
nheads=num_heads,
hidden_dim=hidden_dims,
pooling_size=pooling_sizes,
feature_resolution=feature_resolution,
window_size=window_sizes,
attention_type=attention_type,
prompt_channel=len(prompt_templates),
cnext_type=cnext_type,
kernel_size=kernel_size,
fast_inference=fast_inference,
topK=topK,
)
self.transformer = transformer
@classmethod
def from_config(cls, cfg):#, in_channels, mask_classification):
ret = {}
ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE
# Aggregator parameters:
ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_GUIDANCE_DIM
ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_GUIDANCE_PROJ_DIM
ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_GUIDANCE_DIM
ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_GUIDANCE_PROJ_DIM
ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_GUIDANCE_DIMS
ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_GUIDANCE_PROJ_DIMS
ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH
ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE
ret["cnext_type"] = cfg.MODEL.SEM_SEG_HEAD.CNEXT_TYPE
ret["kernel_size"] = cfg.MODEL.SEM_SEG_HEAD.KERNEL_SIZE
ret["fast_inference"] = cfg.TEST.FAST_INFERENCE
ret["topK"] = cfg.TEST.TOPK
return ret
def forward(self, x, vis_guidance, ):
logit_scale = self.clip_model.logit_scale
vis = [vis_guidance[k] for k in vis_guidance.keys()][::-1]
text = self.text_features if self.training else self.text_features_test
text = text.repeat(x.shape[0], 1, 1, 1)
out = self.transformer(x, text, vis, logit_scale)
return out
@torch.no_grad()
def class_embeddings(self, classnames, templates, clip_model):
zeroshot_weights = []
for classname in classnames:
if ', ' in classname:
classname_splits = classname.split(', ')
texts = []
for template in templates:
for cls_split in classname_splits:
texts.append(template.format(cls_split))
else:
texts = [template.format(classname) for template in templates] # format with class
if self.tokenizer is not None:
texts = self.tokenizer(texts).cuda()
else:
texts = clip.tokenize(texts).cuda()
class_embeddings = clip_model.encode_text(texts)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
if len(templates) != class_embeddings.shape[0]:
class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
return zeroshot_weights
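# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the prompt-ensembling
# idea behind SEDPredictor.class_embeddings, reduced to a standalone snippet.
# "ViT-B-32" with randomly initialised weights (no checkpoint download) is a
# placeholder assumption; the predictor above uses the ConvNeXt/ViT checkpoints
# chosen in __init__ and the configured template set.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    import open_clip

    model, _, _ = open_clip.create_model_and_transforms("ViT-B-32")
    tokenizer = open_clip.get_tokenizer("ViT-B-32")
    templates = ["A photo of a {} in the scene", "a photo of a {}."]
    classnames = ["cat", "dog"]

    with torch.no_grad():
        per_class = []
        for name in classnames:
            tokens = tokenizer([t.format(name) for t in templates])
            emb = model.encode_text(tokens)                # (num_templates, C)
            emb = emb / emb.norm(dim=-1, keepdim=True)     # one unit vector per prompt
            per_class.append(emb)
        # stacked to (num_templates, T, C); the predictor permutes this to (T, P, C)
        text_feats = torch.stack(per_class, dim=1)
        print(text_feats.shape)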
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Tuple
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList
from detectron2.utils.memory import _ignore_torch_cuda_oom
from einops import rearrange
@META_ARCH_REGISTRY.register()
class SED(nn.Module):
@configurable
def __init__(
self,
*,
backbone: Backbone,
sem_seg_head: nn.Module,
size_divisibility: int,
pixel_mean: Tuple[float],
pixel_std: Tuple[float],
clip_pixel_mean: Tuple[float],
clip_pixel_std: Tuple[float],
train_class_json: str,
test_class_json: str,
sliding_window: bool,
clip_finetune: str,
backbone_multiplier: float,
clip_pretrained: str,
in_features,
fast_inference: bool,
):
"""
Args:
backbone: a backbone module, must follow detectron2's backbone interface
sem_seg_head: a module that predicts semantic segmentation from backbone features
"""
super().__init__()
self.backbone = backbone
self.sem_seg_head = sem_seg_head
if size_divisibility < 0:
size_divisibility = self.backbone.size_divisibility
self.size_divisibility = size_divisibility
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)
self.train_class_json = train_class_json
self.test_class_json = test_class_json
self.clip_finetune = clip_finetune
for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
if "visual" in name:
if clip_finetune == "prompt":
params.requires_grad = True if "prompt" in name else False
elif clip_finetune == "conv":
params.requires_grad = True if "conv" in name or "position" in name else False
elif clip_finetune == "full":
params.requires_grad = True
elif clip_finetune == "mlp":
params.requires_grad = True if "mlp" in name or "position" in name else False
elif clip_finetune == "full_res5":
if "stages.3" in name:
params.requires_grad = True
else:
params.requires_grad = False
else:
params.requires_grad = False
else:
params.requires_grad = False
if clip_finetune == "fast_infer":
for name, params in self.sem_seg_head.predictor.transformer.named_parameters():
if "head1" in name or "head2" in name or "head0" in name:
params.requires_grad = True
else:
params.requires_grad = False
finetune_backbone = backbone_multiplier > 0.
for name, params in self.backbone.named_parameters():
if "norm0" in name:
params.requires_grad = False
else:
params.requires_grad = finetune_backbone
self.sliding_window = sliding_window
# self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
self.clip_resolution = (768, 768)
self.sequential = False
del self.backbone
self.in_features = in_features
self.fast_inference = fast_inference
self.clip_finetune = clip_finetune
@classmethod
def from_config(cls, cfg):
backbone = build_backbone(cfg)
sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
return {
"backbone": backbone,
"sem_seg_head": sem_seg_head,
"size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
"pixel_mean": cfg.MODEL.PIXEL_MEAN,
"pixel_std": cfg.MODEL.PIXEL_STD,
"clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
"clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
"train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
"test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
"sliding_window": cfg.TEST.SLIDING_WINDOW,
"clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
"backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
"clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
"in_features": cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES,
"fast_inference": cfg.TEST.FAST_INFERENCE,
}
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* "image": Tensor, image in (C, H, W) format.
* "instances": per-region ground truth
* Other information that's included in the original dicts, such as:
"height", "width" (int): the output resolution of the model (may be different
from input resolution), used in inference.
Returns:
list[dict]:
each dict has the results for one image. The dict contains the following keys:
* "sem_seg":
A Tensor that represents the
                    per-pixel segmentation predicted by the head.
The prediction has shape KxHxW that represents the logits of
each class for each pixel.
"""
images = [x["image"].to(self.device) for x in batched_inputs]
self.sliding_window = False
if not self.training:
self.size_divisibility = -1
if not self.training and self.sliding_window:
if not self.sequential:
with _ignore_torch_cuda_oom():
return self.inference_sliding_window(batched_inputs)
self.sequential = True
return self.inference_sliding_window(batched_inputs)
clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
images = ImageList.from_tensors(images, self.size_divisibility)
clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode='bilinear', align_corners=False, )
clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
images_resized = F.interpolate(images.tensor, size=(384, 384), mode='bilinear', align_corners=False,)
# features = self.backbone(images_resized)
clip_vis_dense = clip_features["clip_vis_dense"]
fusion_features = {k: v.clone().detach() for k,v in clip_features.items() if k in self.in_features}
outputs = self.sem_seg_head(clip_vis_dense, fusion_features)
if self.training:
print_flag = False
for name, param in self.named_parameters():
                if param.grad is None and param.requires_grad:
print(name)
print_flag = True
if print_flag:
print("--------------------------------------------------------------------\n")
targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
num_classes = outputs[0].shape[1]
mask = targets != self.sem_seg_head.ignore_value
losses = {}
for i, output_ in enumerate(outputs):
if self.clip_finetune == "fast_infer" and i==0:
continue
output_ = F.interpolate(output_, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False)
output_ = output_.permute(0,2,3,1)
_targets = torch.zeros(output_.shape, device=self.device)
_onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
_targets[mask] = _onehot
loss = F.binary_cross_entropy_with_logits(output_, _targets)
losses.update({f"loss_sem_seg_{i}" : loss})
return losses
else:
if self.fast_inference:
outputs = outputs[0]
else:
outputs = outputs[0].sigmoid()
image_size = images.image_sizes[0]
height = batched_inputs[0].get("height", image_size[0])
width = batched_inputs[0].get("width", image_size[1])
output = sem_seg_postprocess(outputs[0], image_size, height, width)
processed_results = [{'sem_seg': output}]
return processed_results
@torch.no_grad()
def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
stride = int(kernel * (1 - overlap))
unfold = nn.Unfold(kernel_size=kernel, stride=stride)
fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)
image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
image = torch.cat((image, global_image), dim=0)
images = (image - self.pixel_mean) / self.pixel_std
clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode='bilinear', align_corners=False, )
if self.sequential:
outputs = []
for clip_feat, image in zip(clip_images, images):
feature = self.backbone(image.unsqueeze(0))
clip_feat = self.sem_seg_head.predictor.clip_model.encode_image(clip_feat.unsqueeze(0), dense=True)
output = self.sem_seg_head(clip_feat, feature)
outputs.append(output[0])
outputs = torch.stack(outputs, dim=0)
else:
# features = self.backbone(images)
features = {}
clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
outputs = self.sem_seg_head(clip_features["clip_vis_dense"], features)
outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
outputs = outputs.sigmoid()
global_output = outputs[-1:]
global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
outputs = outputs[:-1]
outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
outputs = (outputs + global_output) / 2.
height = batched_inputs[0].get("height", out_res[0])
width = batched_inputs[0].get("width", out_res[1])
output = sem_seg_postprocess(outputs[0], out_res, height, width)
return [{'sem_seg': output}]
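# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the unfold/fold
# arithmetic used by SED.inference_sliding_window. Overlapping crops are
# re-assembled with nn.Fold, and dividing by fold(unfold(ones)) averages the
# overlapping regions. Sizes mirror the defaults above; the input is random.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch import nn

    out_res, kernel, overlap = [640, 640], 384, 0.333
    stride = int(kernel * (1 - overlap))
    unfold = nn.Unfold(kernel_size=kernel, stride=stride)
    fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)

    T = 4                                          # number of classes (assumed)
    scores = torch.rand(1, T, *out_res)            # stand-in per-class score maps
    crops = unfold(scores)                         # (1, T*kernel*kernel, L) crops
    counts = fold(unfold(torch.ones([1, 1, *out_res])))    # how often each pixel is covered
    recon = fold(crops) / counts                   # summing then averaging the overlaps
    print(torch.allclose(recon, scores, atol=1e-5))         # True: unfold/fold round-trips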
# Copyright (c) Facebook, Inc. and its affiliates.
import copy
from itertools import count
import numpy as np
import torch
from fvcore.transforms import HFlipTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from detectron2.data.detection_utils import read_image
from detectron2.modeling import DatasetMapperTTA
__all__ = [
"SemanticSegmentorWithTTA",
]
class SemanticSegmentorWithTTA(nn.Module):
"""
A SemanticSegmentor with test-time augmentation enabled.
Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
"""
def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
"""
Args:
cfg (CfgNode):
model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
tta_mapper (callable): takes a dataset dict and returns a list of
augmented versions of the dataset dict. Defaults to
`DatasetMapperTTA(cfg)`.
batch_size (int): batch the augmented images into this batch size for inference.
"""
super().__init__()
if isinstance(model, DistributedDataParallel):
model = model.module
self.cfg = cfg.clone()
self.model = model
if tta_mapper is None:
tta_mapper = DatasetMapperTTA(cfg)
self.tta_mapper = tta_mapper
self.batch_size = batch_size
def _batch_inference(self, batched_inputs):
"""
Execute inference on a list of inputs,
using batch size = self.batch_size, instead of the length of the list.
Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`
"""
outputs = []
inputs = []
for idx, input in zip(count(), batched_inputs):
inputs.append(input)
if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
with torch.no_grad():
outputs.extend(self.model(inputs))
inputs = []
return outputs
def __call__(self, batched_inputs):
"""
Same input/output format as :meth:`SemanticSegmentor.forward`
"""
def _maybe_read_image(dataset_dict):
ret = copy.copy(dataset_dict)
if "image" not in ret:
image = read_image(ret.pop("file_name"), self.model.input_format)
image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW
ret["image"] = image
if "height" not in ret and "width" not in ret:
ret["height"] = image.shape[1]
ret["width"] = image.shape[2]
return ret
return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]
def _inference_one_image(self, input):
"""
Args:
input (dict): one dataset dict with "image" field being a CHW tensor
Returns:
dict: one output dict
"""
augmented_inputs, tfms = self._get_augmented_inputs(input)
# 1: forward with all augmented images
outputs = self._batch_inference(augmented_inputs)
# Delete now useless variables to avoid being out of memory
del augmented_inputs
# 2: merge the results
# handle flip specially
new_outputs = []
for output, tfm in zip(outputs, tfms):
if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
new_outputs.append(output.pop("sem_seg").flip(dims=[2]))
else:
new_outputs.append(output.pop("sem_seg"))
del outputs
# to avoid OOM with torch.stack
final_predictions = new_outputs[0]
for i in range(1, len(new_outputs)):
final_predictions += new_outputs[i]
final_predictions = final_predictions / len(new_outputs)
del new_outputs
return {"sem_seg": final_predictions}
def _get_augmented_inputs(self, input):
augmented_inputs = self.tta_mapper(input)
tfms = [x.pop("transforms") for x in augmented_inputs]
return augmented_inputs, tfms
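# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the flip-and-average
# merge performed in _inference_one_image, reduced to plain tensors. A
# prediction made on a horizontally flipped input must be flipped back
# (dims=[2] for a C x H x W semantic map) before averaging.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    sem_seg = torch.rand(3, 8, 8)              # stand-in prediction for the original image
    sem_seg_flipped = sem_seg.flip(dims=[2])   # what the model would see for the HFlip'ed copy
    merged = (sem_seg + sem_seg_flipped.flip(dims=[2])) / 2
    print(torch.allclose(merged, sem_seg))     # True: flipping back undoes the augmentation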
import hashlib
import os
import urllib
import warnings
from typing import Union, List
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
#from .model import build_model
from .model_vpt import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
return download_target
def available_models():
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, prompt_depth=0, prompt_length=0):
if name not in _MODELS:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
model_path = _download(_MODELS[name])
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
n_px = model.input_resolution.item()
transform = Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
if not jit:
model = build_model(model.state_dict(), prompt_depth, prompt_length).to(device)
return model, transform
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if device == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, transform
def load_custom(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, n_px=224):
if name not in _MODELS:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
model_path = _download(_MODELS[name])
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
# n_px = model.input_resolution.item()
transform = Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
if not jit:
model = build_model(model.state_dict()).to(device)
return model, transform
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if device == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, transform
def tokenize(texts: Union[str, List[str]], context_length: int = 77):
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result
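# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): what tokenize() above
# returns. Prompts are zero-padded to the 77-token context length, starting
# with <|startoftext|> and ending with <|endoftext|>.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokens = tokenize(["a photo of a cat", "a photo of a dog"])
    print(tokens.shape)    # torch.Size([2, 77])
    print(tokens.dtype)    # torch.int64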
# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
IMAGENET_TEMPLATES = [
'a bad photo of a {}.',
'a photo of many {}.',
'a sculpture of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a tattoo of a {}.',
'the embroidered {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a plastic {}.',
'a photo of the dirty {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a {} in a video game.',
'a photo of one {}.',
'a doodle of a {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'the origami {}.',
'the {} in a video game.',
'a sketch of a {}.',
'a doodle of the {}.',
'a origami {}.',
'a low resolution photo of a {}.',
'the toy {}.',
'a rendition of the {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a rendition of a {}.',
'a photo of a nice {}.',
'a photo of a weird {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a sketch of the {}.',
'a embroidered {}.',
'a pixelated photo of a {}.',
'itap of the {}.',
'a jpeg corrupted photo of the {}.',
'a good photo of a {}.',
'a plushie {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'the cartoon {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a black and white photo of a {}.',
'the plushie {}.',
'a dark photo of a {}.',
'itap of a {}.',
'graffiti of the {}.',
'a toy {}.',
'itap of my {}.',
'a photo of a cool {}.',
'a photo of a small {}.',
'a tattoo of the {}.',
# 'A photo of a {} in the scene.',
]
# v1: 59.0875
IMAGENET_TEMPLATES_SELECT = [
'itap of a {}.',
'a bad photo of the {}.',
'a origami {}.',
'a photo of the large {}.',
'a {} in a video game.',
'art of the {}.',
'a photo of the small {}.',
'A photo of a {} in the scene',
]
# v2: 58.2584
# IMAGENET_TEMPLATES_SELECT = [
# 'itap of a {}',
# 'a bad photo of the {}',
# 'a origami {}',
# 'a photo of the large {}',
# 'art of the {}',
# 'a photo of the small {}',
# 'A photo of a {} in the scene',
# ]
# v3: 59.1006
# IMAGENET_TEMPLATES_SELECT = [
# 'itap of a {}.',
# 'a bad photo of the {}.',
# 'a origami {}.',
# 'a photo of the large {}.',
# 'art of the {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'A photo of a {} in the scene',
# 'itap of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a origami {} in the scene',
# 'a photo of the large {} in the scene',
# 'art of the {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# ]
# v4: 59.8659
# IMAGENET_TEMPLATES_SELECT = [
# 'a bad photo of the {}.',
# 'a photo of the large {}.',
# 'art of the {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'A photo of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a photo of the large {} in the scene',
# 'art of the {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# 'a photo of a masked {} in the scene',
# ]
# v5: 59.9346
# IMAGENET_TEMPLATES_SELECT = [
# 'a bad photo of the {}.',
# 'a photo of the large {}.',
# 'art of the {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
# 'A photo of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a photo of the large {} in the scene',
# 'art of the {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# 'a photo of a masked {} in the scene',
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
# ]
# v6: 60.6611
# IMAGENET_TEMPLATES_SELECT = [
# 'a bad photo of the {}.',
# 'a photo of the large {}.',
# 'art of the {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
# 'A photo of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a photo of the large {} in the scene',
# 'art of the {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# 'a photo of a masked {} in the scene',
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
#
# 'There is a masked {} in the scene',
# 'There is the masked {} in the scene',
# 'This is a masked {} in the scene',
# 'This is the masked {} in the scene',
# 'This is one masked {} in the scene',
# ]
# v7: 60.4529
# IMAGENET_TEMPLATES_SELECT = [
# 'a bad photo of the {}.',
# 'a photo of the large {}.',
# 'art of the {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
# 'A photo of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a photo of the large {} in the scene',
# 'art of the {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# 'a photo of a masked {} in the scene',
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
#
# 'There is a cropped {} in the scene',
# 'There is the cropped {} in the scene',
# 'This is a cropped {} in the scene',
# 'This is the cropped {} in the scene',
# 'This is one cropped {} in the scene',
#
# 'a cropped photo of the {}',
# 'a cropped photo of a {}',
# 'a cropped photo of one {}',
#
# 'There is a masked {} in the scene',
# 'There is the masked {} in the scene',
# 'This is a masked {} in the scene',
# 'This is the masked {} in the scene',
# 'This is one masked {} in the scene',
# ]
# v8: 60.7057
# IMAGENET_TEMPLATES_SELECT = [
# 'a bad photo of the {}.',
# 'a photo of the large {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
#
# 'This is a masked photo of a {}',
# 'This is a masked photo of a small {}',
# 'This is a masked photo of a medium {}',
# 'This is a masked photo of a large {}',
#
# 'A photo of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a photo of the large {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# 'a photo of a masked {} in the scene',
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
#
# 'There is a masked {} in the scene',
# 'There is the masked {} in the scene',
# 'This is a masked {} in the scene',
# 'This is the masked {} in the scene',
# 'This is one masked {} in the scene',
# ]
# v9: 60.8775
# IMAGENET_TEMPLATES_SELECT = [
# 'a bad photo of the {}.',
# 'a photo of the large {}.',
# 'a photo of the small {}.',
# 'a cropped photo of a {}.',
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
#
# 'This is a masked photo of a {}',
# 'This is a masked photo of a small {}',
# 'This is a masked photo of a medium {}',
# 'This is a masked photo of a large {}',
#
# 'This is a cropped photo of a {}',
# 'This is a cropped photo of a small {}',
# 'This is a cropped photo of a medium {}',
# 'This is a cropped photo of a large {}',
#
# 'A photo of a {} in the scene',
# 'a bad photo of the {} in the scene',
# 'a photo of the large {} in the scene',
# 'a photo of the small {} in the scene',
# 'a cropped photo of a {} in the scene',
# 'a photo of a masked {} in the scene',
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
#
# 'There is a masked {} in the scene',
# 'There is the masked {} in the scene',
# 'This is a masked {} in the scene',
# 'This is the masked {} in the scene',
# 'This is one masked {} in the scene',
# ]
# v9
IMAGENET_TEMPLATES_SELECT_CLIP = [
'a bad photo of the {}.',
'a photo of the large {}.',
'a photo of the small {}.',
'a cropped photo of a {}.',
'This is a photo of a {}',
'This is a photo of a small {}',
'This is a photo of a medium {}',
'This is a photo of a large {}',
'This is a masked photo of a {}',
'This is a masked photo of a small {}',
'This is a masked photo of a medium {}',
'This is a masked photo of a large {}',
'This is a cropped photo of a {}',
'This is a cropped photo of a small {}',
'This is a cropped photo of a medium {}',
'This is a cropped photo of a large {}',
'A photo of a {} in the scene',
'a bad photo of the {} in the scene',
'a photo of the large {} in the scene',
'a photo of the small {} in the scene',
'a cropped photo of a {} in the scene',
'a photo of a masked {} in the scene',
'There is a {} in the scene',
'There is the {} in the scene',
'This is a {} in the scene',
'This is the {} in the scene',
'This is one {} in the scene',
'There is a masked {} in the scene',
'There is the masked {} in the scene',
'This is a masked {} in the scene',
'This is the masked {} in the scene',
'This is one masked {} in the scene',
]
# v10, for comparison
# IMAGENET_TEMPLATES_SELECT_CLIP = [
# 'a photo of a {}.',
#
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
#
# 'This is a photo of a {}',
# 'This is a photo of a small {}',
# 'This is a photo of a medium {}',
# 'This is a photo of a large {}',
#
# 'a photo of a {} in the scene',
# 'a photo of a {} in the scene',
#
# 'There is a {} in the scene',
# 'There is the {} in the scene',
# 'This is a {} in the scene',
# 'This is the {} in the scene',
# 'This is one {} in the scene',
# ]
ViLD_templates = [
'There is {article} {category} in the scene.',
'There is the {category} in the scene.',
'a photo of {article} {category} in the scene.',
'a photo of the {category} in the scene.',
'a photo of one {category} in the scene.',
'itap of {article} {category}.',
'itap of my {category}.',
'itap of the {category}.',
'a photo of {article} {category}.',
'a photo of my {category}.',
'a photo of the {category}.',
'a photo of one {category}.',
'a photo of many {category}.',
'a good photo of {article} {category}.',
'a good photo of the {category}.',
'a bad photo of {article} {category}.',
'a bad photo of the {category}.',
'a photo of a nice {category}.',
'a photo of the nice {category}.',
'a photo of a cool {category}.',
'a photo of the cool {category}.',
'a photo of a weird {category}.',
'a photo of the weird {category}.',
'a photo of a small {category}.',
'a photo of the small {category}.',
'a photo of a large {category}.',
'a photo of the large {category}.',
'a photo of a clean {category}.',
'a photo of the clean {category}.',
'a photo of a dirty {category}.',
'a photo of the dirty {category}.',
'a bright photo of {article} {category}.',
'a bright photo of the {category}.',
'a dark photo of {article} {category}.',
'a dark photo of the {category}.',
'a photo of a hard to see {category}.',
'a photo of the hard to see {category}.',
'a low resolution photo of {article} {category}.',
'a low resolution photo of the {category}.',
'a cropped photo of {article} {category}.',
'a cropped photo of the {category}.',
'a close-up photo of {article} {category}.',
'a close-up photo of the {category}.',
'a jpeg corrupted photo of {article} {category}.',
'a jpeg corrupted photo of the {category}.',
'a blurry photo of {article} {category}.',
'a blurry photo of the {category}.',
'a pixelated photo of {article} {category}.',
'a pixelated photo of the {category}.',
'a black and white photo of the {category}.',
'a black and white photo of {article} {category}.',
'a plastic {category}.',
'the plastic {category}.',
'a toy {category}.',
'the toy {category}.',
'a plushie {category}.',
'the plushie {category}.',
'a cartoon {category}.',
'the cartoon {category}.',
'an embroidered {category}.',
'the embroidered {category}.',
'a painting of the {category}.',
'a painting of a {category}.'
]
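# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): unlike the templates
# above, ViLD_templates use {article}/{category} placeholders. The article
# heuristic below is an assumption for illustration only, not repo code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def _article(name):
        return 'an' if name[0].lower() in 'aeiou' else 'a'

    category = 'orange'
    prompts = [t.format(article=_article(category), category=category) for t in ViLD_templates]
    print(prompts[0])    # "There is an orange in the scene."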
from collections import OrderedDict
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
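# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): AttentionPool2d turns an
# NCHW feature map into one pooled vector per image by attending from the mean
# token (the query) over all spatial tokens. Sizes below are assumed values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pool = AttentionPool2d(spacial_dim=7, embed_dim=64, num_heads=8, output_dim=32)
    feats = torch.randn(2, 64, 7, 7)   # N C H W; H*W must equal spacial_dim ** 2
    print(pool(feats).shape)           # torch.Size([2, 32])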
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.mask_pre_mlp = True
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
def forward_dense(self, x: torch.Tensor):
y = self.ln_1(x)
y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
L, N, D = y.shape # L N 3D
y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0, 3).reshape(3 * N, L, D // 3)
y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
q, k, v = y.tensor_split(3, dim=0)
v = v.transpose(1, 0) + x # L N D
v = v + self.mlp(self.ln_2(v))
return v
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor, dense=False):
for i, resblock in enumerate(self.resblocks):
if i == self.layers - 1 and dense:
x = resblock.forward_dense(x)
else:
x = resblock(x)
return x
class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.patch_size = patch_size
self.input_resolution = input_resolution
def forward(self, x: torch.Tensor, dense=False):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
if dense and (x.shape[1] != self.positional_embedding.shape[0]):
x = x + self.resized_pos_embed(self.input_resolution, x.shape[1]).to(x.dtype)
else:
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, dense)
x = x.permute(1, 0, 2) # LND -> NLD
if dense:
x = self.ln_post(x[:, :, :])
else:
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
def resized_pos_embed(self, in_res, tgt_res, mode="bicubic"):
#assert L == (input_resolution // self.patch_size) ** 2 + 1
L, D = self.positional_embedding.shape
in_side = in_res // self.patch_size
#tgt_side = tgt_res // self.patch_size
tgt_side = int((tgt_res - 1) ** 0.5)
cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
pos_embed = self.positional_embedding[1:].reshape(1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
resized_pos_embed = F.interpolate(pos_embed, size=(tgt_side, tgt_side), mode=mode, align_corners=False,) # 1 D S S -> 1 D S' S'
resized_pos_embed = resized_pos_embed.squeeze(0).reshape(D, -1).T # L'-1 D
return torch.cat((cls_pos, resized_pos_embed), dim=0)
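# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the interpolation inside
# VisualTransformer.resized_pos_embed. For dense inference at a larger input
# size, the CLS position is split off, the spatial grid of positional
# embeddings is bicubically resized, and the two are re-concatenated. Sizes
# below are assumed values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    D, in_side, tgt_side = 8, 14, 24
    pos = torch.randn(in_side * in_side + 1, D)
    cls_pos, grid = pos[:1], pos[1:]
    grid = grid.reshape(1, in_side, in_side, D).permute(0, 3, 1, 2)      # 1 D S S
    grid = F.interpolate(grid, size=(tgt_side, tgt_side), mode="bicubic",
                         align_corners=False)                            # 1 D S' S'
    grid = grid.squeeze(0).reshape(D, -1).T                              # S'*S' D
    print(torch.cat((cls_pos, grid), dim=0).shape)                       # (24*24 + 1, 8)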
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
self.image_resolution = image_resolution
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]))
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image, masks=None, pool_mask=None, dense=False):
if pool_mask is not None:
return self.visual(image.type(self.dtype), mask=pool_mask, dense=dense)
        if masks is None:
return self.visual(image.type(self.dtype), dense=dense)
else:
return self.visual(image.type(self.dtype), masks.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# import pdb; pdb.set_trace()
# normalized features
# image_features shape: [1, 1024]
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
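# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the logits returned by
# CLIP.forward are scaled cosine similarities, and the two matrices are each
# other's transpose. Feature sizes below are assumed values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    image_features = F.normalize(torch.randn(4, 512), dim=-1)
    text_features = F.normalize(torch.randn(6, 512), dim=-1)
    logit_scale = torch.ones([]).exp()
    logits_per_image = logit_scale * image_features @ text_features.t()  # (4, 6)
    logits_per_text = logit_scale * text_features @ image_features.t()   # (6, 4)
    print(torch.allclose(logits_per_image, logits_per_text.t()))         # True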
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
return model.eval()
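# Usage sketch (not part of the original file): a minimal, hedged example of driving
# build_model() with an OpenAI-style CLIP checkpoint, which ships as a TorchScript
# archive. The path "RN50.pt" is a hypothetical placeholder, not a file in this repo.
if __name__ == "__main__":
    jit_archive = torch.jit.load("RN50.pt", map_location="cpu")
    clip_model = build_model(jit_archive.state_dict()).float()  # convert_weights() cast to fp16; .float() keeps CPU inference simple
    tokens = torch.randint(0, clip_model.vocab_size, (2, clip_model.context_length))
    with torch.no_grad():
        print(clip_model.encode_text(tokens).shape)  # torch.Size([2, embed_dim])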
from collections import OrderedDict
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
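# Illustration (not part of the original file): as the comments in Bottleneck note,
# downsampling happens through AvgPool2d rather than a strided convolution. A quick
# shape check with made-up sizes shows a stride-2 block halving spatial resolution
# while expanding channels by Bottleneck.expansion.
def _bottleneck_shape_demo():
    block = Bottleneck(inplanes=64, planes=64, stride=2)
    out = block(torch.randn(1, 64, 56, 56))
    assert out.shape == (1, 64 * Bottleneck.expansion, 28, 28)
    return out.shape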
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
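# Illustration (not part of the original file): a hedged sketch of how ModifiedResNet
# might be configured for an RN50-style backbone. layers=(3, 4, 6, 3) and width=64 are
# the standard ResNet-50 settings; heads=32 follows the vision_heads = vision_width * 32 // 64
# rule used by the CLIP constructor further below, and the output is the pooled
# embedding produced by AttentionPool2d.
def _modified_resnet_demo():
    visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                            input_resolution=224, width=64)
    feats = visual(torch.randn(2, 3, 224, 224))
    assert feats.shape == (2, 1024)
    return feats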
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.mask_pre_mlp = True
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
    def forward_dense(self, x: torch.Tensor):
        # dense prediction path: project q, k, v, apply the output projection to every
        # token, then keep only the value branch so each token stays a local feature
        # (no attention mixing in this block)
        y = self.ln_1(x)
        y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
        L, N, D = y.shape  # D here is 3 * d_model (concatenated q, k, v)
        y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0, 3).reshape(3 * N, L, D // 3)
        y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
        q, k, v = y.tensor_split(3, dim=0)
        v = v.transpose(1, 0) + x  # L N D, residual on the value branch
        v = v + self.mlp(self.ln_2(v))
        return v
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompt_length=0, prompt_depth=0):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
self.prompt_length = prompt_length
self.prompt_depth = prompt_depth
self.prompt_tokens = nn.Parameter(torch.zeros(prompt_depth, prompt_length, width)) if prompt_length > 0 else None
if self.prompt_tokens is not None:
nn.init.xavier_uniform_(self.prompt_tokens)
def forward(self, x: torch.Tensor, dense=False):
for i, resblock in enumerate(self.resblocks):
if self.prompt_length > 0 and i < self.prompt_depth:
l = self.prompt_length + 1 if i > 0 else 1
                x = torch.cat((x[0:1, :, :], self.prompt_tokens[i].repeat(x.shape[1], 1, 1).permute(1, 0, 2), x[l:, :, :]))
if i == self.layers - 1 and dense:
x = resblock.forward_dense(x)
                x = torch.cat((x[0:1, :, :], x[self.prompt_length + 1:, :, :]), dim=0)
else:
x = resblock(x)
return x
class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, prompt_depth: int, prompt_length: int):
super().__init__()
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads, prompt_depth=prompt_depth, prompt_length=prompt_length)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.patch_size = patch_size
self.input_resolution = input_resolution
def forward(self, x: torch.Tensor, dense=False):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
if dense and (x.shape[1] != self.positional_embedding.shape[0]):
x = x + self.resized_pos_embed(self.input_resolution, x.shape[1]).to(x.dtype)
else:
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, dense)
x = x.permute(1, 0, 2) # LND -> NLD
if dense:
x = self.ln_post(x[:, :, :])
else:
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
def resized_pos_embed(self, in_res, tgt_res, mode="bicubic"):
#assert L == (input_resolution // self.patch_size) ** 2 + 1
L, D = self.positional_embedding.shape
in_side = in_res // self.patch_size
#tgt_side = tgt_res // self.patch_size
tgt_side = int((tgt_res - 1) ** 0.5)
cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
pos_embed = self.positional_embedding[1:].reshape(1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
resized_pos_embed = F.interpolate(pos_embed, size=(tgt_side, tgt_side), mode=mode, align_corners=False,) # 1 D S S -> 1 D S' S'
resized_pos_embed = resized_pos_embed.squeeze(0).reshape(D, -1).T # L'-1 D
return torch.cat((cls_pos, resized_pos_embed), dim=0)
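# Illustration (not part of the original file): in dense mode the positional embedding
# is bicubically resized by resized_pos_embed(), so the ViT can ingest inputs larger
# than the resolution it was built for and return one feature per patch token instead
# of a single pooled vector. The sizes below are made up for the shape check.
def _visual_transformer_dense_demo():
    vit = VisualTransformer(input_resolution=224, patch_size=16, width=192, layers=2,
                            heads=3, output_dim=64, prompt_depth=0, prompt_length=0)
    dense_feats = vit(torch.randn(1, 3, 320, 320), dense=True)
    assert dense_feats.shape == (1, (320 // 16) ** 2 + 1, 64)  # cls token + 20x20 patch tokens
    return dense_feats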
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
# prompt
prompt_depth: int=0,
prompt_length: int=0,
):
super().__init__()
self.context_length = context_length
self.image_resolution = image_resolution
if isinstance(vision_layers, (tuple, list)):
assert prompt_length == 0 and prompt_depth==0
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
prompt_depth=prompt_depth,
prompt_length=prompt_length,
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]))
    def build_attention_mask(self):
        # lazily create the causal attention mask for the text transformer;
        # pytorch uses an additive attention mask, so blocked positions hold -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero the diagonal and below; keep -inf strictly above (no attending to future tokens)
        return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image, masks=None, pool_mask=None, dense=False):
if pool_mask is not None:
return self.visual(image.type(self.dtype), mask=pool_mask, dense=dense)
        if masks is None:
return self.visual(image.type(self.dtype), dense=dense)
else:
return self.visual(image.type(self.dtype), masks.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict, prompt_depth=0, prompt_length=0):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
prompt_depth=prompt_depth, prompt_length=prompt_length,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict, strict=False)
return model.eval()
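# Usage sketch (not part of the original file): this prompt-aware variant of build_model
# accepts prompt_depth/prompt_length (ViT checkpoints only, per the assertion in CLIP)
# and loads with strict=False so the newly created prompt_tokens keep their fresh
# initialization instead of being read from the checkpoint. "ViT-B-16.pt" is a
# hypothetical placeholder path.
def _build_prompted_clip():
    state_dict = torch.jit.load("ViT-B-16.pt", map_location="cpu").state_dict()
    return build_model(state_dict, prompt_depth=3, prompt_length=8)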
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
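# Illustration (not part of the original file): bytes_to_unicode() builds a 256-entry
# bijection from raw byte values to printable unicode characters, which is what keeps
# the BPE vocabulary reversible.
def _byte_encoder_demo():
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    assert len(byte_encoder) == 256
    assert all(byte_decoder[byte_encoder[b]] == b for b in range(256))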
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
                except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
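# Usage sketch (not part of the original file): a round trip through the tokenizer.
# encode() lower-cases, whitespace-normalizes, and BPE-encodes the text; decode() maps
# the ids back and replaces the end-of-word marker with a space, so the result is a
# normalized reconstruction of the input rather than the exact original string.
def _tokenizer_round_trip(text="A photo of a cat."):
    tokenizer = SimpleTokenizer()
    ids = tokenizer.encode(text)
    return ids, tokenizer.decode(ids)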
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
from typing import List, Optional
import torch
import torch.distributed as dist
import torchvision
from torch import Tensor
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
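# Illustration (not part of the original file): two images of different sizes are padded
# to a common shape; the boolean mask marks padded pixels with True. Sizes are made up.
def _nested_tensor_demo():
    imgs = [torch.randn(3, 480, 640), torch.randn(3, 512, 600)]
    nt = nested_tensor_from_tensor_list(imgs)
    tensors, mask = nt.decompose()
    assert tensors.shape == (2, 3, 512, 640)
    assert mask.shape == (2, 512, 640) and mask.dtype == torch.bool
    return nt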
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(
torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True