Commit 106580f9 authored by chenych's avatar chenych
Browse files

First commit

parents
Pipeline #689 failed with stages
in 0 seconds
# 模型唯一标识
modelCode=xxx
# 模型名称
modelName=hdetr_pytorch
# 模型描述
modelDescription=HDETR引入了一种混合匹配方案,这个新的匹配机制允许将多个查询分配给每个正样本,从而提高了训练效果,适用于多种视觉任务如目标检测、3D物体检测、姿势估计和对象跟踪等
# 应用场景
appScenario=推理,训练,目标检测,教育,交通,公安
# 框架类型
frameType=PyTorch
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
##########################
import fvcore.nn.weight_init as weight_init
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from util.vitdet_utils import (
PatchEmbed,
add_decomposed_rel_pos,
get_abs_pos,
window_partition,
window_unpartition,
LayerNorm2D,
)
class Attention(nn.Module):
    """Multi-head self-attention over a 2-D token grid, with optional relative position embeddings."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): (height, width) of the token grid, used to
                size the relative positional parameters. Required when ``use_rel_pos`` is True.

        Raises:
            ValueError: If ``use_rel_pos`` is True but ``input_size`` is None.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                # Fail early with a readable message instead of
                # "'NoneType' object is not subscriptable" below.
                raise ValueError("input_size must be provided when use_rel_pos is True")
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, H, W, C) and return the same shape."""
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        # Move heads next to channels and merge them back into the channel dim.
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x
class ResBottleneckBlock(CNNBlockBase):
    """
    Residual bottleneck block made of 1x1 -> 3x3 -> 1x1 convolutions, with no
    activation applied after the last normalization layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): Number of output channels of the 1x1 reduce
                conv and the 3x3 "bottleneck" conv.
            norm (str or callable): Normalization used after every conv layer.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): Factory for the activation used after the first
                two normalization layers.
        """
        super().__init__(in_channels, out_channels, 1)

        # NOTE: forward() walks self.children() in registration order, so the
        # order of the assignments below defines the computation.
        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False)
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for conv in (self.conv1, self.conv2, self.conv3):
            weight_init.c2_msra_fill(conv)

        for inner_norm in (self.norm1, self.norm2):
            inner_norm.weight.data.fill_(1.0)
            inner_norm.bias.data.zero_()

        # The last norm starts at zero so the whole block is initially an identity mapping.
        for tensor in (self.norm3.weight, self.norm3.bias):
            tensor.data.zero_()

    def forward(self, x):
        """Run every registered child in order and add the input back as a skip connection."""
        branch = x
        for module in self.children():
            branch = module(branch)
        return x + branch
class Block(nn.Module):
    """Transformer block: (optionally windowed) attention, MLP, and an optional conv residual block."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        use_residual_block=False,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
            qkv_bias (bool): If True, the query/key/value projection has a learnable bias.
            drop_path (float): Stochastic depth rate; 0 disables it.
            norm_layer (nn.Module): Factory for the two normalization layers.
            act_layer (nn.Module): Factory for the activation layers.
            use_rel_pos (bool): If True, attention uses relative positional embeddings.
            rel_pos_zero_init (bool): If True, relative positional parameters start at zero.
            window_size (int): Side of the attention window; 0 means attention over the
                whole token grid.
            use_residual_block (bool): If True, append a conv bottleneck block after the MLP.
            input_size (tuple(int, int) or None): Token grid size used to size the relative
                positional parameters when ``window_size`` is 0.
        """
        super().__init__()

        # With windowed attention the relative position tables cover one window only.
        attn_grid = (window_size, window_size) if window_size != 0 else input_size
        mlp_hidden = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_grid,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden, act_layer=act_layer)

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Conv bottleneck with half as many channels as the token embedding.
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
                act_layer=act_layer,
            )

    def forward(self, x):
        """Process ``x`` of shape (B, H, W, C); the output has the same shape."""
        skip = x
        x = self.norm1(x)

        if self.window_size > 0:
            # Attend inside each window, then stitch the windows back together.
            height, width = x.shape[1], x.shape[2]
            windows, pad_hw = window_partition(x, self.window_size)
            windows = self.attn(windows)
            x = window_unpartition(windows, self.window_size, pad_hw, (height, width))
        else:
            x = self.attn(x)

        x = skip + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if not self.use_residual_block:
            return x

        # The conv block is channels-first; convert to NCHW and back.
        return self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
class Painter(nn.Module):
    """Painter model: ViT encoder over (image, target) pairs followed by a small conv decoder.

    Derived from the Masked Autoencoder with VisionTransformer backbone.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        decoder_embed_dim=128,
        loss_func="smoothl1",
    ):
        """
        Args:
            img_size (int or tuple(int, int)): Input (height, width); an int means a square input.
            patch_size (int): Side of each square patch.
            in_chans (int): Number of input image channels.
            embed_dim (int): Token embedding width.
            depth (int): Number of transformer blocks. ``forward_encoder`` taps blocks
                5, 11, 17 and 23, so the decoder expects at least 24 blocks.
            num_heads (int): Attention heads per block.
            mlp_ratio (float): MLP hidden width as a multiple of ``embed_dim``.
            qkv_bias (bool): If True, qkv projections have a learnable bias.
            drop_path_rate (float): Maximum stochastic depth rate (linearly increased per block).
            norm_layer (nn.Module): Normalization layer factory.
            act_layer (nn.Module): Activation layer factory.
            use_abs_pos (bool): If True, use a learnable absolute positional embedding.
            use_rel_pos (bool): If True, attention uses relative positional embeddings.
            rel_pos_zero_init (bool): If True, relative positional parameters start at zero.
            window_size (int): Window size for the blocks listed in ``window_block_indexes``.
            window_block_indexes (sequence of int): Blocks that use window attention.
            residual_block_indexes (sequence of int): Blocks followed by a conv residual block.
            use_act_checkpoint (bool): If True, wrap every block in ``checkpoint_wrapper``.
            pretrain_img_size (int): Image size used to size the absolute positional embedding.
            pretrain_use_cls_token (bool): If True, the positional embedding reserves one
                extra position for a cls token.
            out_feature (str): Name under which the output feature metadata is stored.
            decoder_embed_dim (int): Channel width of the decoder.
            loss_func (str): One of "l1l2", "l1", "l2", "smoothl1".
        """
        super().__init__()

        # A bare int means a square input. The original code indexed img_size[0]
        # unconditionally and therefore crashed on its own default value.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)

        # --------------------------------------------------------------------------
        self.pretrain_use_cls_token = pretrain_use_cls_token
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.patch_embed.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        self.segment_token_x = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        self.segment_token_y = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim), requires_grad=True)
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                use_residual_block=i in residual_block_indexes,
                input_size=(img_size[0] // patch_size, img_size[1] // patch_size),
            )
            if use_act_checkpoint:
                block = checkpoint_wrapper(block)
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # Decoder: the four tapped encoder features are concatenated (hence embed_dim*4),
        # projected to per-pixel features and refined by a small conv head.
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim*4, patch_size ** 2 * self.decoder_embed_dim, bias=True)
        self.decoder_pred = nn.Sequential(
            nn.Conv2d(self.decoder_embed_dim, self.decoder_embed_dim, kernel_size=3, padding=1, ),
            LayerNorm2D(self.decoder_embed_dim),
            nn.GELU(),
            nn.Conv2d(self.decoder_embed_dim, 3, kernel_size=1, bias=True),
        )
        # --------------------------------------------------------------------------
        self.loss_func = loss_func

        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.segment_token_x, std=.02)
        torch.nn.init.normal_(self.segment_token_y, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zero bias / unit weight for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W), with H == 2 * W and H divisible by the patch size
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_size
        assert imgs.shape[2] == 2 * imgs.shape[3] and imgs.shape[2] % p == 0

        w = imgs.shape[3] // p
        h = w * 2
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3), where the patch grid is twice as tall as it is wide
        imgs: (N, 3, H, W)
        """
        p = self.patch_size
        # L = h * w with h = 2 * w, so w = sqrt(L / 2).
        w = int((x.shape[1]*0.5)**.5)
        h = w * 2
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs

    def forward_encoder(self, imgs, tgts, bool_masked_pos):
        """Encode images and (partially masked) targets.

        Returns a list of normalized feature maps taken after blocks 5, 11, 17 and 23.
        """
        # embed patches
        x = self.patch_embed(imgs)
        y = self.patch_embed(tgts)
        batch_size, Hp, Wp, _ = x.size()
        seq_len = Hp * Wp

        mask_token = self.mask_token.expand(batch_size, Hp, Wp, -1)
        # replace the masked visual tokens by mask_token
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, Hp, Wp, 1)
        y = y * (1 - w) + mask_token * w

        # add pos embed w/o cls token
        x = x + self.segment_token_x
        y = y + self.segment_token_y
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )
            y = y + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (y.shape[1], y.shape[2])
            )

        # Image and target streams are stacked along the batch dim and averaged
        # back into one stream right after block `merge_idx`.
        merge_idx = 2
        x = torch.cat((x, y), dim=0)
        # apply Transformer blocks
        out = []
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if idx == merge_idx:
                x = (x[:x.shape[0]//2] + x[x.shape[0]//2:]) * 0.5
            if idx in [5, 11, 17, 23]:
                out.append(self.norm(x))
        return out

    def forward_decoder(self, x):
        """Fuse the list of encoder features ``x`` and predict a (B, 3, H, W) image."""
        # predictor projection
        x = torch.cat(x, dim=-1)
        x = self.decoder_embed(x)
        p = self.patch_size
        h, w = x.shape[1], x.shape[2]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.decoder_embed_dim))
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(shape=(x.shape[0], -1, h * p, w * p))

        x = self.decoder_pred(x)  # Bx3xHxW
        return x

    def forward_loss(self, pred, tgts, mask, valid):
        """
        tgts: [N, 3, H, W]
        pred: [N, 3, H, W]
        mask: [N, L], 0 is keep, 1 is remove,
        valid: [N, 3, H, W] per-pixel loss weights, or None to weight every pixel equally

        Returns the loss averaged over the masked, valid pixels.

        Raises:
            ValueError: If ``self.loss_func`` is not one of the supported names.
        """
        mask = mask[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
        mask = self.unpatchify(mask)

        if valid is None:
            # forward() defaults valid to None; treat that as "everything is valid".
            valid = torch.ones_like(tgts)
        else:
            # Work on a copy so the caller's tensor is not modified in place.
            valid = valid.clone()

        # ignore if the unmasked pixels are all zeros
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(tgts.device)[None, :, None, None]
        imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(tgts.device)[None, :, None, None]
        inds_ign = ((tgts * imagenet_std + imagenet_mean) * (1 - 1.*mask)).sum((1, 2, 3)) < 100*3
        valid[inds_ign] = 0.
        mask = mask * valid

        target = tgts
        if self.loss_func == "l1l2":
            loss = ((pred - target).abs() + (pred - target) ** 2.) * 0.5
        elif self.loss_func == "l1":
            loss = (pred - target).abs()
        elif self.loss_func == "l2":
            loss = (pred - target) ** 2.
        elif self.loss_func == "smoothl1":
            loss = F.smooth_l1_loss(pred, target, reduction="none", beta=0.01)
        else:
            # Previously an unknown name surfaced as an UnboundLocalError on `loss`.
            raise ValueError(f"Unsupported loss_func: {self.loss_func!r}")
        loss = (loss * mask).sum() / (mask.sum() + 1e-2)  # mean loss on removed patches
        return loss

    def forward(self, imgs, tgts, bool_masked_pos=None, valid=None):
        """Run encoder, decoder and loss.

        Returns:
            tuple: (loss, patchified prediction, boolean mask of shape [N, L]).
        """
        if bool_masked_pos is None:
            bool_masked_pos = torch.zeros((imgs.shape[0], self.patch_embed.num_patches), dtype=torch.bool).to(imgs.device)
        else:
            bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
        latent = self.forward_encoder(imgs, tgts, bool_masked_pos)
        pred = self.forward_decoder(latent)  # [N, 3, H, W]
        loss = self.forward_loss(pred, tgts, bool_masked_pos, valid)
        return loss, self.patchify(pred), bool_masked_pos
def painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1(**kwargs):
    """Build the ViT-Large Painter for 896x448 inputs (patch 16, decoder width 64, smooth-L1 loss).

    Extra keyword arguments are forwarded to ``Painter``.
    """
    model = Painter(
        img_size=(896, 448), patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        drop_path_rate=0.1, window_size=14, qkv_bias=True,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        # NOTE(review): only the first five range lists are joined with `+`; the last
        # three are separated by commas, so this evaluates to a tuple of four lists
        # rather than a flat sequence of block indices. Painter tests
        # `i in window_block_indexes` with an int `i`, which never equals a list, so
        # no block actually gets window attention. Flattening it would change the
        # rel_pos parameter shapes built in Attention, so confirm compatibility with
        # existing checkpoints before changing this expression.
        window_block_indexes=(list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + \
                                list(range(12, 14)), list(range(15, 17)), list(range(18, 20)), list(range(21, 23))),
        residual_block_indexes=[], use_rel_pos=True, out_feature="last_feat",
        decoder_embed_dim=64,
        loss_func="smoothl1",
        **kwargs)
    return model
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Compute the layer-wise lr multiplier for a parameter of a ViT backbone.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        ``lr_decay_rate`` raised to the number of layers between the parameter and the top.
        Parameters outside the backbone, and those of residual conv blocks, get 1.0.
    """
    top = num_layers + 1
    depth_index = top

    if name.startswith("backbone"):
        if any(tag in name for tag in (".pos_embed", ".patch_embed")):
            # Embeddings sit below every block.
            depth_index = 0
        elif ".blocks." in name and ".residual." not in name:
            # "<prefix>.blocks.<idx>.<rest>" -> idx + 1
            _, _, tail = name.partition(".blocks.")
            depth_index = int(tail.split(".")[0]) + 1

    return lr_decay_rate ** (top - depth_index)
timm==0.3.2
git+https://github.com/cocodataset/panopticapi.git
h5py # for depth
xtcocotools # for pose
natsort # for denoising
wandb
\ No newline at end of file
#!/bin/bash
# Launch multi-node training of painter_vit_large via main_train.py.
#
# Expects WORLD_SIZE (number of nodes), RANK (index of this node) and
# MASTER_ADDR to be set in the environment; starts 8 processes per node.
# DATA_PATH is the dataset root that every --json_path / --val_json_path entry
# below is resolved against.
#
# The argument list is a single command joined by trailing backslashes, so no
# comment can be placed between the lines. The last argument also ends with a
# backslash; presumably this is so the commented-out `--log_wandb` flag that
# follows can be enabled by removing its leading `#`.
DATA_PATH=datasets
name=painter_vit_large
python -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=${WORLD_SIZE} --node_rank=$RANK \
--master_addr=$MASTER_ADDR --master_port=12358 \
--use_env main_train.py \
--batch_size 2 \
--accum_iter 16 \
--model painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1 \
--num_mask_patches 784 \
--max_mask_patches_per_block 392 \
--epochs 15 \
--warmup_epochs 1 \
--lr 1e-3 \
--clip_grad 3 \
--layer_decay 0.8 \
--drop_path 0.1 \
--input_size 896 448 \
--save_freq 1 \
--data_path $DATA_PATH/ \
--json_path \
$DATA_PATH/nyu_depth_v2/nyuv2_sync_image_depth.json \
$DATA_PATH/ade20k/ade20k_training_image_semantic.json \
$DATA_PATH/coco/pano_ca_inst/coco_train_image_panoptic_inst.json \
$DATA_PATH/coco/pano_sem_seg/coco_train2017_image_panoptic_sem_seg.json \
$DATA_PATH/coco_pose/coco_pose_256x192_train.json \
$DATA_PATH/denoise/denoise_ssid_train.json \
$DATA_PATH/derain/derain_train.json \
$DATA_PATH/light_enhance/enhance_lol_train.json \
--val_json_path \
$DATA_PATH/nyu_depth_v2/nyuv2_test_image_depth.json \
$DATA_PATH/ade20k/ade20k_validation_image_semantic.json \
$DATA_PATH/coco/pano_ca_inst/coco_val_image_panoptic_inst.json \
$DATA_PATH/coco/pano_sem_seg/coco_val2017_image_panoptic_sem_seg.json \
$DATA_PATH/coco_pose/coco_pose_256x192_val.json \
$DATA_PATH/denoise/denoise_ssid_val.json \
$DATA_PATH/derain/derain_test_rain100h.json \
$DATA_PATH/light_enhance/enhance_lol_val.json \
--output_dir models/$name \
--log_dir models/$name/logs \
--finetune path/to/mae_pretrain_vit_large.pth \
# --log_wandb \
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """

    @staticmethod
    def get_params(img, scale, ratio):
        """Sample crop parameters for ``img`` in a single draw.

        Args:
            img: Image to crop (anything torchvision can report a size for).
            scale (sequence of float): (min, max) fraction of the image area to keep.
            ratio (sequence of float): (min, max) aspect ratio (width / height) of the crop.

        Returns:
            tuple: (i, j, h, w) - top, left, height and width of the crop,
            with h and w clamped to the image size.
        """
        # Prefer the public torchvision helper and fall back to the private name
        # used by older releases, so the transform works across versions.
        get_image_size = getattr(F, "get_image_size", None) or F._get_image_size
        width, height = get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        # Sample the aspect ratio log-uniformly.
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        # Clamp instead of re-sampling when the crop exceeds the image.
        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import os
import PIL
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args):
    """Create an ImageFolder dataset from the 'train' or 'val' subfolder of ``args.data_path``."""
    split = 'train' if is_train else 'val'
    dataset = datasets.ImageFolder(
        os.path.join(args.data_path, split),
        transform=build_transform(is_train, args),
    )

    print(dataset)

    return dataset
def build_transform(is_train, args):
    """Return the training augmentation pipeline or the deterministic eval pipeline.

    Both pipelines normalize with the ImageNet default mean and std.
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Eval: resize so the center crop keeps the same ratio as 224-crop-of-256 images.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_to = int(args.input_size / crop_pct)

    return transforms.Compose([
        transforms.Resize(resize_to, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
import os
import glob
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.distributed as dist
class DatasetTest(Dataset):
    """
    Dataset over the image files of one directory, intended for DDP inference.

    Each item is ``(img, img_path, size_org)`` where ``img`` is the RGB image
    resized to ``input_size`` x ``input_size`` as a float array scaled to [0, 1],
    and ``size_org`` is the original PIL size.
    """

    def __init__(self, img_src_dir, input_size, ext_list=('*.png', '*.jpg'), ):
        """
        Args:
            img_src_dir (str): Directory that is searched (non-recursively) for images.
            input_size (int): Side length the images are resized to.
            ext_list (sequence of str): Glob patterns of the files to include.
        """
        super(DatasetTest, self).__init__()
        self.img_src_dir = img_src_dir
        self.input_size = input_size

        img_path_list = []
        for ext in ext_list:
            # glob returns files in arbitrary, filesystem-dependent order; sort so
            # every process maps the same index to the same file.
            img_path_tmp = sorted(glob.glob(os.path.join(img_src_dir, ext)))
            img_path_list.extend(img_path_tmp)
        self.img_path_list = img_path_list

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        img_path = self.img_path_list[index]
        # Close the file as soon as the pixels are decoded instead of leaking the handle.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        size_org = img.size
        img = img.resize((self.input_size, self.input_size))
        img = np.array(img) / 255.

        return img, img_path, size_org
def collate_fn(batch):
    """Identity collate: hand the list of samples to the caller unchanged (no stacking)."""
    return batch
    # batch = list(zip(*batch))
    # return tuple(batch)
def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that only the master process prints.

    Any process can still print by passing ``force=True``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        # `force` is consumed here and never reaches the real print.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the number of distributed processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return the rank of this process, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when this process has rank 0."""
    return get_rank() == 0
def init_distributed_mode(args):
    """
    Configure ``args`` for distributed execution and initialize the process group.

    Rank information is read from RANK / WORLD_SIZE / LOCAL_RANK if all three are
    set, otherwise from SLURM_PROCID (in which case ``args.world_size`` must
    already be set by the caller). If neither is available, ``args.distributed``
    is set to False and nothing is initialized. ``args.dist_url`` is read as the
    init method.

    Returns:
        The same ``args`` object, updated in place.
    """
    env = os.environ

    if all(key in env for key in ('RANK', 'WORLD_SIZE', 'LOCAL_RANK')):
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return args

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    banner = '| distributed init (rank {}): {}'.format(args.rank, args.dist_url)
    print(banner, flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    # From here on only rank 0 prints (unless force=True is passed).
    setup_for_distributed(args.rank == 0)
    return args
import os
import glob
import json
import tqdm
import argparse
import shutil
def get_args_parser():
    """Parse the command line of the toy-dataset script and return the parsed namespace."""
    cli = argparse.ArgumentParser('get toy dataset for json', add_help=False)
    cli.add_argument('--json_path', required=True, type=str, help='path to json file')
    # Source and target dataset roots.
    for flag, fallback in (('--data_s', 'datasets'), ('--data_t', 'toy_datasets')):
        cli.add_argument(flag, type=str, default=fallback)
    cli.add_argument('--num_sample', default=10, type=int, help='number of samples')
    return cli.parse_args()
if __name__ == "__main__":
    # Build a small "toy" copy of a dataset: keep the first `num_sample` entries
    # of the json index and copy the files they reference from data_s to data_t.
    args = get_args_parser()

    with open(args.json_path, 'r') as f:
        dataset_full = json.load(f)
    dataset = dataset_full[:args.num_sample]

    for data in dataset:
        image_path_src = os.path.join(args.data_s, data['image_path'])
        target_path_src = os.path.join(args.data_s, data['target_path'])
        print(image_path_src)

        image_path_tgt = os.path.join(args.data_t, data['image_path'])
        target_path_tgt = os.path.join(args.data_t, data['target_path'])

        os.makedirs(os.path.dirname(image_path_tgt), exist_ok=True)
        os.makedirs(os.path.dirname(target_path_tgt), exist_ok=True)

        shutil.copy(image_path_src, image_path_tgt)
        shutil.copy(target_path_src, target_path_tgt)

    # Mirror the json under the target root. With the default roots this is the
    # same 'datasets' -> 'toy_datasets' substitution as before, but it now
    # follows --data_s / --data_t when they are overridden.
    save_path = args.json_path.replace(args.data_s, args.data_t)
    if os.path.abspath(save_path) == os.path.abspath(args.json_path):
        # Nothing was substituted: writing would truncate the source json itself.
        raise ValueError(
            "json_path %r does not contain the source root %r; refusing to overwrite it"
            % (args.json_path, args.data_s)
        )
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, 'w') as f:
        json.dump(dataset, f)
    print(save_path)
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
import torch
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        """
        Args:
            params: Iterable of parameters or parameter-group dicts.
            lr (float): Learning rate.
            weight_decay (float): L2 coefficient added to the gradient of >1D parameters.
            momentum (float): Momentum factor of the update buffer.
            trust_coefficient (float): Scale of the layer-wise trust ratio.
        """
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): Re-evaluates the model and returns the loss,
                as in the standard ``torch.optim.Optimizer.step`` contract.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs gradients.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio; falls back to 1 when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])

        return loss
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import json
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Build optimizer parameter groups with layer-wise lr decay.

    Trainable parameters are grouped by (layer id, decay / no_decay). Every group
    carries an ``lr_scale`` of ``layer_decay ** (num_layers - layer_id)`` and its
    ``weight_decay``; 1-D parameters and names in ``no_weight_decay_list`` get no decay.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    group_names = {}
    groups = {}

    num_layers = len(model.blocks) + 1
    scales = [layer_decay ** (num_layers - depth) for depth in range(num_layers + 1)]

    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        exempt = param.ndim == 1 or param_name in no_weight_decay_list
        decay_tag, decay_value = ("no_decay", 0.) if exempt else ("decay", weight_decay)

        layer_id = get_layer_id_for_vit(param_name, num_layers)
        key = "layer_%d_%s" % (layer_id, decay_tag)

        if key not in groups:
            scale = scales[layer_id]
            # Parallel dict holding names instead of tensors, kept for debugging.
            group_names[key] = {
                "lr_scale": scale,
                "weight_decay": decay_value,
                "params": [],
            }
            groups[key] = {
                "lr_scale": scale,
                "weight_decay": decay_value,
                "params": [],
            }

        group_names[key]["params"].append(param_name)
        groups[key]["params"].append(param)

    # print("parameter groups: \n%s" % json.dumps(group_names, indent=2))

    return list(groups.values())
def get_layer_id_for_vit(name, num_layers):
    """
    Map a parameter name to its layer id for layer-wise lr decay.

    Embedding parameters map to 0, ``blocks.<i>.*`` maps to ``i + 1`` and
    everything else maps to ``num_layers``.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0
    if name.startswith('blocks'):
        block_index = name.split('.')[1]
        return int(block_index) + 1
    return num_layers
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
def adjust_learning_rate(optimizer, epoch, args):
    """Set the lr of every param group: linear warmup, then half-cycle cosine decay.

    Groups that define ``lr_scale`` get the scheduled lr multiplied by it.
    Returns the unscaled lr.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        phase = math.pi * (epoch - warmup) / (args.epochs - warmup)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(phase))

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr

    return lr
# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------
import random
import math
import numpy as np
class MaskingGenerator:
    """Generate random block-wise patch masks with an exact number of masked patches."""

    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        """
        Args:
            input_size (int or sequence of two ints): Patch grid (height, width);
                an int means a square grid.
            num_masking_patches (int): Exact number of patches masked per call.
            min_num_patches (int): Lower bound of the sampled area of one block.
            max_num_patches (int or None): Upper bound of the patches added by one
                block; defaults to ``num_masking_patches``.
            min_aspect (float): Minimum block aspect ratio.
            max_aspect (float or None): Maximum block aspect ratio; defaults to
                ``1 / min_aspect``.
        """
        # Accept lists as well as tuples (e.g. values coming from argparse nargs);
        # previously a list was duplicated into a pair of lists.
        if isinstance(input_size, (tuple, list)):
            input_size = tuple(input_size)
        else:
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the (height, width) of the patch grid."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try up to 10 times to add one rectangular block to ``mask`` in place.

        A block is accepted only if it adds between 1 and ``max_mask_patches``
        new masked patches. Returns the number of newly masked patches (0 if
        every attempt failed).
        """
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                block = mask[top: top + h, left: left + w]
                # The mask only holds 0/1, so the block's zeros are exactly the
                # patches this block would newly mask (overlap is not re-counted).
                newly_masked = h * w - int(block.sum())
                if 0 < newly_masked <= max_mask_patches:
                    block[...] = 1
                    return newly_masked
        return 0

    def __call__(self):
        """Return an int32 array of the grid shape with exactly ``num_masking_patches`` ones."""
        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        # maintain a fixed number {self.num_masking_patches}
        if mask_count > self.num_masking_patches:
            delta = mask_count - self.num_masking_patches
            mask_x, mask_y = mask.nonzero()
            to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_vis], mask_y[to_vis]] = 0
        elif mask_count < self.num_masking_patches:
            # Block sampling gave up early; top up with individually chosen patches.
            delta = self.num_masking_patches - mask_count
            mask_x, mask_y = (mask == 0).nonzero()
            to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_mask], mask_y[to_mask]] = 1

        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"

        return mask
if __name__ == '__main__':
    # Ad-hoc stress test: draw masks repeatedly and drop into the debugger if a
    # mask does not contain exactly the requested 118 patches.
    # NOTE(review): the nesting of the two print calls is reconstructed (inside
    # the failure branch); confirm against the original file.
    import pdb
    generator = MaskingGenerator(input_size=14, num_masking_patches=118, min_num_patches=16, )
    for i in range(10000000):
        mask = generator()
        if mask.sum() != 118:
            pdb.set_trace()
            print(mask)
            print(mask.sum())
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def mask_matrix_nms(masks,
                    labels,
                    scores,
                    filter_thr=-1,
                    nms_pre=-1,
                    max_num=-1,
                    kernel='gaussian',
                    sigma=2.0,
                    mask_area=None):
    """Matrix NMS for multi-class masks.

    Instead of discarding overlapping masks, every score is decayed by a
    coefficient derived from its IoU with higher-scored masks of the same
    label.

    Args:
        masks (Tensor): Masks with shape (num_instances, h, w).
        labels (Tensor): Label of each mask, shape (num_instances,).
        scores (Tensor): Score of each mask, shape (num_instances,).
        filter_thr (float): Decayed scores below this value are dropped.
            Values <= 0 (default -1) disable the filter.
        nms_pre (int): Keep only the top ``nms_pre`` scored masks before the
            NMS. Values <= 0 (default -1) disable the limit.
        max_num (int): Keep only the top ``max_num`` masks after the NMS.
            Values <= 0 (default -1) disable the limit.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): Multiplier used inside the gaussian kernel.
        mask_area (Tensor | None): Per-mask area; computed as
            ``masks.sum((1, 2))`` when omitted.

    Returns:
        tuple(Tensor): ``(scores, labels, masks, keep_inds)``, sorted by
        decayed score in descending order. ``keep_inds`` indexes into the
        input ``masks``.

    Raises:
        NotImplementedError: if ``kernel`` is neither 'gaussian' nor 'linear'.
    """
    assert len(labels) == len(masks) == len(scores)

    def _empty_result():
        # Zero-length outputs that keep the dtype/device of the inputs.
        return (scores.new_zeros(0), labels.new_zeros(0),
                masks.new_zeros(0, *masks.shape[-2:]), labels.new_zeros(0))

    if len(labels) == 0:
        return _empty_result()

    if mask_area is None:
        mask_area = masks.sum((1, 2)).float()
    else:
        assert len(masks) == len(mask_area)

    # Order candidates by score and optionally truncate to the top nms_pre.
    scores, order = torch.sort(scores, descending=True)
    keep_inds = order
    if 0 < nms_pre < len(order):
        order = order[:nms_pre]
        keep_inds = keep_inds[:nms_pre]
        scores = scores[:nms_pre]
    masks, mask_area, labels = masks[order], mask_area[order], labels[order]

    count = len(labels)
    flat = masks.reshape(count, -1).float()

    # Pairwise intersection and IoU; only the strict upper triangle is kept,
    # i.e. entry (i, j) relates mask j to a higher-scored mask i.
    inter = torch.mm(flat, flat.transpose(1, 0))
    area_rows = mask_area.expand(count, count)
    union = area_rows + area_rows.transpose(1, 0) - inter
    iou = (inter / union).triu(diagonal=1)

    # Pairs only interact when they share a label.
    label_rows = labels.expand(count, count)
    same_label = (label_rows == label_rows.transpose(1, 0)).triu(diagonal=1)
    decay_iou = iou * same_label

    # For each mask: its largest same-label IoU with a higher-scored mask,
    # broadcast along rows to serve as the compensation term.
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(count, count).transpose(1, 0)

    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError(
            f'{kernel} kernel is not supported in matrix nms!')

    scores = scores * decay_coefficient

    if filter_thr > 0:
        keep = scores >= filter_thr
        keep_inds = keep_inds[keep]
        if not keep.any():
            return _empty_result()
        masks, scores, labels = masks[keep], scores[keep], labels[keep]

    # Re-sort by decayed score and optionally truncate to the top max_num.
    scores, order = torch.sort(scores, descending=True)
    keep_inds = keep_inds[order]
    if 0 < max_num < len(order):
        order = order[:max_num]
        keep_inds = keep_inds[:max_num]
        scores = scores[:max_num]
    return scores, labels[order], masks[order], keep_inds
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import json
import torch
import torch.distributed as dist
from torch._six import inf
class SmoothedValue(object):
    """Running statistics of a scalar series.

    The most recent ``window_size`` values are kept in a deque for windowed
    statistics (median, avg, max, value); ``total`` and ``count`` accumulate
    over the whole series for ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        """
        Args:
            window_size (int): Maximum number of recent values retained.
            fmt (str | None): Format string used by ``__str__``; it may
                reference median, avg, global_avg, max and value.
        """
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value``; it counts ``n`` times towards the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        count, total = packed.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress lines."""

    def __init__(self, delimiter="\t"):
        """
        Args:
            delimiter (str): Separator placed between fields of a log line.
        """
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one scalar per keyword into the meter of the same name.

        ``None`` values are skipped; tensors are converted with ``.item()``.
        """
        for name, value in kwargs.items():
            if value is None:
                continue
            if isinstance(value, torch.Tensor):
                value = value.item()
            assert isinstance(value, (float, int))
            self.meters[name].update(value)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss``.

        Raises:
            AttributeError: if ``attr`` is neither a meter nor instance state.
        """
        # Go through __dict__ directly: looking up `self.meters` before it
        # exists (bare instance, copy, unpickling) would re-enter
        # __getattr__ and recurse without end.
        state = self.__dict__
        meters = state.get('meters', ())
        if attr in meters:
            return meters[attr]
        if attr in state:
            return state[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` items and on the last item,
        followed by a total-time summary once the iterable is exhausted.

        Args:
            iterable: Sized iterable (``len()`` must be supported).
            print_freq (int): Number of items between two log lines.
            header (str | None): Prefix of every log line.
        """
        if not header:
            header = ''
        num_items = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the item index to the width of the total count.
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        has_cuda = torch.cuda.is_available()
        if has_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if has_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps the summary from dividing by zero on an empty
        # iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_items, 1)))
def setup_for_distributed(is_master):
    """Replace ``builtins.print`` with a rank-aware, time-stamped variant.

    After this call, ``print`` only emits output when ``is_master`` is true,
    when the caller passes ``force=True``, or when the world size exceeds 8.

    Args:
        is_master (bool): Whether this process is allowed to print by default.
    """
    original_print = builtins.print

    def print(*args, **kwargs):
        # `force` is consumed here so it never reaches the real print.
        forced = kwargs.pop('force', False) or (get_world_size() > 8)
        if not (is_master or forced):
            return
        stamp = datetime.datetime.now().time()
        original_print('[{}] '.format(stamp), end='')  # print with time stamp
        original_print(*args, **kwargs)

    builtins.print = print
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save`` on the main process only."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Initialise torch.distributed from the launcher's environment variables.

    The rank/GPU assignment is read from, in priority order: OpenMPI
    variables (when ``args.dist_on_itp`` is set), ``RANK``/``WORLD_SIZE``/
    ``LOCAL_RANK``, or ``SLURM_PROCID``. When none applies, the function
    sets ``args.distributed = False`` and returns without creating a
    process group.

    Args:
        args: Namespace that is updated in place (rank, world_size, gpu,
            dist_url, distributed, dist_backend as applicable).
    """
    env = os.environ
    if args.dist_on_itp:
        args.rank = int(env['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(env['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(env['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (env['MASTER_ADDR'], env['MASTER_PORT'])
        # Re-export the OpenMPI values under the generic variable names:
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
        exported = (
            ('LOCAL_RANK', args.gpu),
            ('RANK', args.rank),
            ('WORLD_SIZE', args.world_size),
        )
        for key, value in exported:
            env[key] = str(value)
    elif 'RANK' in env and 'WORLD_SIZE' in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        # NOTE(review): this branch does not set args.world_size; it is
        # presumably expected to be present on args already - confirm.
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    # From here on only rank 0 prints by default.
    setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler`` and report the gradient norm per step."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Back-propagate the scaled loss and optionally step the optimizer.

        Args:
            loss (Tensor): Loss to back-propagate.
            optimizer: Optimizer that is unscaled and stepped.
            clip_grad (float | None): Max gradient norm; when given,
                gradients are clipped and ``parameters`` is required.
            parameters: Parameters whose gradient norm is clipped/measured.
            create_graph (bool): Forwarded to ``backward``.
            update_grad (bool): When False, only the backward pass runs
                (e.g. for gradient accumulation) and None is returned.

        Returns:
            Tensor | None: Gradient norm, or None when ``update_grad`` is False.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        clipping = clip_grad is not None
        if clipping:
            assert parameters is not None
        # unscale the gradients of optimizer's assigned params in-place
        self._scaler.unscale_(optimizer)
        if clipping:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            norm = get_grad_norm_(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Compute the norm of the gradients of ``parameters`` without modifying them.

    Args:
        parameters: A single tensor or an iterable of tensors; entries whose
            ``.grad`` is None are ignored.
        norm_type (float): Order of the norm; ``float('inf')`` selects the
            max-abs norm.

    Returns:
        Tensor: Scalar norm over all gradients, viewed as one flat vector
        (``tensor(0.)`` when no parameter has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    # Results are gathered on the device of the first gradient.
    device = grads[0].device
    # Compare against the builtin infinity instead of `inf` from the private
    # `torch._six` module, so this function does not depend on that import.
    if norm_type == float('inf'):
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Write a training checkpoint for ``epoch`` under ``args.output_dir``.

    With a ``loss_scaler`` the checkpoint is a single
    ``checkpoint-<epoch>.pth`` file saved by the main process; without one,
    saving is delegated to ``model.save_checkpoint`` (the DeepSpeed-style
    path, judging by the sibling loader).

    Args:
        args: Run configuration; ``args.output_dir`` is the target directory
            and ``args`` itself is stored in the ``.pth`` checkpoint.
        epoch: Epoch identifier used in the file name and stored inside.
        model: Model exposing ``save_checkpoint`` (used when no scaler).
        model_without_ddp: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        loss_scaler: Scaler whose ``state_dict`` is saved, or None.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)

    if loss_scaler is None:
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state={'epoch': epoch},
        )
        return

    checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
    payload = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(payload, checkpoint_path)
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Restore model (and, when training, optimizer/scaler) from ``args.resume``.

    Does nothing when ``args.resume`` is empty. Paths starting with
    'https' are downloaded through ``torch.hub``; anything else is read
    with ``torch.load``. Optimizer, start epoch and scaler are restored
    only when the checkpoint contains them and ``args.eval`` is not set.

    Args:
        args: Namespace with ``resume``; ``start_epoch`` is written on resume.
        model_without_ddp: Module receiving ``checkpoint['model']``.
        optimizer: Optimizer receiving ``checkpoint['optimizer']``.
        loss_scaler: Scaler receiving ``checkpoint['scaler']``.
    """
    if not args.resume:
        return

    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % args.resume)

    if 'optimizer' in checkpoint and 'epoch' in checkpoint and not getattr(args, 'eval', False):
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch'] + 1
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        print("With optim & sched!")
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler):
    """Resume training, optionally auto-detecting the latest checkpoint.

    With a ``loss_scaler`` (torch.amp path): when ``args.auto_resume`` is
    set and ``args.resume`` is empty, the highest-numbered
    ``checkpoint-<N>.pth`` in ``args.output_dir`` becomes ``args.resume``;
    the checkpoint is then loaded like in ``load_model``. Without a scaler
    (DeepSpeed path) only ``--auto_resume`` is supported and loading is
    delegated to ``model.load_checkpoint``.

    Args:
        args: Namespace with ``output_dir``, ``auto_resume`` and ``resume``;
            ``resume`` and ``start_epoch`` may be updated.
        model: Model exposing ``load_checkpoint`` (DeepSpeed path).
        model_without_ddp: Module receiving ``checkpoint['model']``.
        optimizer: Optimizer receiving ``checkpoint['optimizer']``.
        loss_scaler: Scaler receiving ``checkpoint['scaler']``, or None.
    """
    output_dir = Path(args.output_dir)

    def _latest_epoch(pattern):
        # Highest numeric suffix among paths matching `pattern`, or -1.
        import glob
        latest = -1
        for path in glob.glob(pattern):
            # Text between the last '-' and the following '.'
            suffix = path.split('-')[-1].split('.')[0]
            if suffix.isdigit():
                latest = max(int(suffix), latest)
        return latest

    if loss_scaler is None:
        # deepspeed, only support '--auto_resume'.
        if args.auto_resume:
            latest_ckpt = _latest_epoch(os.path.join(output_dir, 'checkpoint-*'))
            if latest_ckpt >= 0:
                args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
                print("Auto resume checkpoint: %d" % latest_ckpt)
                _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
                args.start_epoch = client_states['epoch'] + 1
        return

    # torch.amp
    if args.auto_resume and len(args.resume) == 0:
        latest_ckpt = _latest_epoch(os.path.join(output_dir, 'checkpoint-*.pth'))
        if latest_ckpt >= 0:
            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
        print("Auto resume checkpoint: %s" % args.resume)

    if not args.resume:
        return

    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % args.resume)
    if 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch'] + 1
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
            print('loss scaler', checkpoint['scaler'])
        print("With optim & sched!")
def all_reduce_mean(x):
    """Average the scalar ``x`` over all processes.

    With a single process ``x`` is returned untouched; otherwise the value
    is summed via ``dist.all_reduce`` on the GPU, divided by the world size
    and returned as a Python number.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
def create_ds_config(args):
    """Write ``deepspeed_config.json`` into ``args.output_dir``.

    The path of the written file is stored in ``args.deepspeed_config``.

    Args:
        args: Namespace providing output_dir, batch_size, accum_iter, lr,
            weight_decay, opt_betas, opt_eps, clip_grad and zero_stage.

    Raises:
        NotImplementedError: if ``args.zero_stage`` is greater than 1.
    """
    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
    with open(args.deepspeed_config, mode="w") as writer:
        optimizer_cfg = {
            "type": "Adam",
            "adam_w_mode": True,
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "bias_correction": True,
                "betas": [
                    args.opt_betas[0],
                    args.opt_betas[1]
                ],
                "eps": args.opt_eps
            }
        }
        fp16_cfg = {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1
        }
        amp_cfg = {
            "enabled": False,
            "opt_level": "O2"
        }
        profiler_cfg = {
            "enabled": True,
            "profile_step": -1,
            "module_depth": -1,
            "top_modules": 1,
            "detailed": True,
        }
        # Key order is kept stable because it determines the layout of the
        # written JSON file.
        ds_config = {
            "train_batch_size": args.batch_size * args.accum_iter * get_world_size(),
            "train_micro_batch_size_per_gpu": args.batch_size,
            "steps_per_print": 1000,
            "optimizer": optimizer_cfg,
            "fp16": fp16_cfg,
            "amp": amp_cfg,
            "flops_profiler": profiler_cfg,
        }

        if args.clip_grad is not None:
            ds_config['gradient_clipping'] = args.clip_grad

        if args.zero_stage == 1:
            ds_config['zero_optimization'] = {"stage": args.zero_stage, "reduce_bucket_size": 5e8}
        elif args.zero_stage > 1:
            raise NotImplementedError()

        writer.write(json.dumps(ds_config, indent=2))
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
    """Split trainable parameters into optimizer groups.

    1-D parameters, parameters whose name ends in ".bias" and names listed
    in ``skip_list`` get zero weight decay; all others get ``weight_decay``.
    When ``get_num_layer`` is given, groups are additionally split per
    layer id, and ``get_layer_scale`` (if given) supplies each group's
    ``lr_scale``.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        weight_decay (float): Decay applied to the "decay" groups.
        skip_list: Parameter names excluded from weight decay.
        get_num_layer: Optional callable mapping a parameter name to a layer id.
        get_layer_scale: Optional callable mapping a layer id to an lr scale.

    Returns:
        list[dict]: Groups with keys "weight_decay", "params" and "lr_scale".
    """
    groups_for_log = {}    # group name -> settings + parameter names
    groups_for_optim = {}  # group name -> settings + parameter tensors

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        exempt = len(param.shape) == 1 or name.endswith(".bias") or name in skip_list
        group_name = "no_decay" if exempt else "decay"
        this_weight_decay = 0. if exempt else weight_decay

        layer_id = None
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)

        if group_name not in groups_for_log:
            scale = get_layer_scale(layer_id) if get_layer_scale is not None else 1.
            for registry in (groups_for_log, groups_for_optim):
                registry[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "lr_scale": scale
                }

        groups_for_optim[group_name]["params"].append(param)
        groups_for_log[group_name]["params"].append(name)

    print("Param groups = %s" % json.dumps(groups_for_log, indent=2))
    return list(groups_for_optim.values())
# -*- coding: utf-8 -*-
import sys
import os
import requests
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import glob
import tqdm
import matplotlib.pyplot as plt
from PIL import Image
sys.path.append('.')
import models_painter
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def get_args_parser():
    """Parse the demo's command-line arguments.

    Despite its name, this returns the parsed ``argparse.Namespace`` (the
    result of ``parse_args()``), not the parser.

    Returns:
        argparse.Namespace: with ``ckpt_dir``, ``model`` and ``epoch``.
    """
    parser = argparse.ArgumentParser('Painter Demo Inference', add_help=False)
    parser.add_argument('--ckpt_dir', type=str, help='dir to ckpt',
                        default='')
    # The help text used to be a copy of the --ckpt_dir one.
    parser.add_argument('--model', type=str, help='model architecture name',
                        default='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1')
    parser.add_argument('--epoch', type=int, help='model epochs',
                        default=14)
    return parser.parse_args()
def prepare_model(chkpt_dir, arch='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1'):
    """Build the ``arch`` model from ``models_painter`` and load its weights.

    Args:
        chkpt_dir (str): Path of the checkpoint file (mapped to 'cuda:0').
        arch (str): Name of the factory function in ``models_painter``.

    Returns:
        The model in eval mode; weights are loaded non-strictly and the
        resulting key-mismatch report is printed.
    """
    # build model
    build_fn = getattr(models_painter, arch)
    model = build_fn()
    # load model
    state = torch.load(chkpt_dir, map_location='cuda:0')
    load_report = model.load_state_dict(state['model'], strict=False)
    print(load_report)
    model.eval()
    return model
def run_one_image(img, tgt, size, model, out_path, device):
    """Run the model on one stacked image/target pair and save the prediction.

    Args:
        img: HWC array of the (normalized) stacked input images.
        tgt: HWC array of the (normalized) stacked target images.
        size: (width, height) the prediction is resized to before saving.
        model: Model exposing ``patch_embed.num_patches`` and ``unpatchify``.
        out_path (str): File the resulting image is written to.
        device: Device the inputs are moved to.
    """
    def _to_nchw(array):
        # HWC array -> 1xCxHxW tensor
        batched = torch.tensor(array).unsqueeze(dim=0)
        return torch.einsum('nhwc->nchw', batched)

    x = _to_nchw(img)
    tgt = _to_nchw(tgt)

    # Mask the second half of the patch sequence; the model fills it in.
    num_patches = model.patch_embed.num_patches
    bool_masked_pos = torch.zeros(num_patches)
    bool_masked_pos[num_patches//2:] = 1
    bool_masked_pos = bool_masked_pos.unsqueeze(dim=0)
    valid = torch.ones_like(tgt)

    loss, y, mask = model(x.float().to(device), tgt.float().to(device), bool_masked_pos.to(device), valid.float().to(device))
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Keep only the lower half of the output image.
    prediction = y[0, y.shape[1]//2:, :, :]
    # Undo the ImageNet normalization and map to the [0, 255] range.
    prediction = torch.clip((prediction * imagenet_std + imagenet_mean) * 255, 0, 255)
    resized = F.interpolate(
        prediction[None, ...].permute(0, 3, 1, 2),
        size=[size[1], size[0]],
        mode='nearest',
    ).permute(0, 2, 3, 1)[0]
    resized = resized.int()
    picture = Image.fromarray(resized.numpy().astype(np.uint8))
    picture.save(out_path)
if __name__ == '__main__':
    # Demo entry point: load a checkpoint, stack a prompt pair (img2/tgt2)
    # on top of the query image and save the model's prediction.
    args = get_args_parser()

    ckpt_dir = args.ckpt_dir
    model = args.model
    epoch = args.epoch
    ckpt_file = 'checkpoint-{}.pth'.format(epoch)
    # endswith() instead of ckpt_dir[-1]: indexing raised IndexError for the
    # default empty --ckpt_dir.
    assert not ckpt_dir.endswith("/")
    ckpt_path = os.path.join(ckpt_dir, ckpt_file)

    model_painter = prepare_model(ckpt_path, model)
    print('Model loaded.')

    device = torch.device("cuda")
    model_painter.to(device)

    # Placeholder paths: edit them to point at the prompt image, the prompt
    # target, the query image and the output directory.
    img2_path = "path/to/img2"
    tgt2_path = "path/to/tgt2"
    img_path = "path/to/img"
    img_name = os.path.basename(img_path)
    out_path = os.path.join("path/to/out", img_name.replace('.jpg', '.png'))

    res = 448

    # Prompt image and prompt target, resized to res x res and scaled to [0, 1].
    img2 = Image.open(img2_path).convert("RGB")
    img2 = img2.resize((res, res))
    img2 = np.array(img2) / 255.

    tgt2 = Image.open(tgt2_path).convert("RGB")
    tgt2 = tgt2.resize((res, res), Image.NEAREST)
    tgt2 = np.array(tgt2) / 255.

    # Query image; its original size is kept to resize the prediction back.
    img = Image.open(img_path).convert("RGB")
    size = img.size
    img = img.resize((res, res))
    img = np.array(img) / 255.

    tgt = tgt2  # tgt is not available
    # Stack prompt on top of query along the height axis.
    tgt = np.concatenate((tgt2, tgt), axis=0)
    img = np.concatenate((img2, img), axis=0)

    assert img.shape == (2*res, res, 3)
    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std

    assert tgt.shape == (2*res, res, 3)
    # normalize by ImageNet mean and std
    tgt = tgt - imagenet_mean
    tgt = tgt / imagenet_std

    torch.manual_seed(2)
    run_one_image(img, tgt, size, model_painter, out_path, device)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine position embedding for a square grid.

    grid_size: int of the grid height and width
    cls_token: when True, a zero row is prepended for the class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token)
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)
    # here w goes first
    grid = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid as sine-cosine embeddings.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
    ``grid[1]``; the two (H*W, D/2) parts are concatenated into (H*W, D).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])  # (H*W, D/2)
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode positions as concatenated sine and cosine features.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; flattened to size (M,)
    out: (M, D), laid out as [sin | cos] with D/2 columns each
    """
    assert embed_dim % 2 == 0
    # np.float64 replaces the `np.float` alias, which newer NumPy releases
    # no longer provide; the resulting dtype is the same.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to fit ``model``.

    Nothing happens when the checkpoint has no 'pos_embed' entry or when the
    (square) grid sizes already match. Otherwise the position tokens are
    bicubically interpolated to the model's grid, while the leading extra
    tokens (class_token / dist_token) are kept unchanged.

    Args:
        model: Model exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model (dict): State dict that is updated in place.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_pos_embed = checkpoint_model['pos_embed']
    channels = ckpt_pos_embed.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((ckpt_pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = ckpt_pos_embed[:, :num_extra_tokens]
    # only the position tokens are interpolated
    grid_tokens = ckpt_pos_embed[:, num_extra_tokens:]
    grid_tokens = grid_tokens.reshape(-1, orig_size, orig_size, channels).permute(0, 3, 1, 2)
    grid_tokens = torch.nn.functional.interpolate(
        grid_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, grid_tokens), dim=1)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Public API of this module. LayerNorm2D is listed as well because the model
# code imports it from here alongside the other helpers.
__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
    "LayerNorm2D",
]
def window_partition(x, window_size):
    """
    Split a token map into non-overlapping square windows, zero-padding the
    bottom/right edges when H or W is not a multiple of the window size.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Amount needed to reach the next multiple of window_size (0 if aligned).
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        # F.pad pairs run from the last dim backwards: C, then W, then H.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    rows, cols = Hp // window_size, Wp // window_size
    tiles = x.view(B, rows, window_size, cols, window_size, C)
    windows = tiles.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reassemble windows into the original token map and strip any padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)

    grid = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    was_padded = Hp > H or Wp > W
    return x[:, :H, :W, :].contiguous() if was_padded else x
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Look up the relative positional embedding for every (query, key) index
    pair along one axis.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C) with the embeddings selected
        according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # Linearly interpolate the table to the required number of offsets.
        table = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        table = table.reshape(-1, max_rel_dist).permute(1, 0)

    # Scale the coords with short length if shapes for q and k are different.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # Shift so that the smallest offset maps to index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[relative_coords.long()]
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Add decomposed (per-axis) Relative Positional Embeddings from :paper:`mvitv2`
    to an attention map.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings,
        shaped (B, q_h * q_w, k_h * k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    q_map = q.reshape(B, q_h, q_w, dim)
    # Query-dependent bias along each axis.
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_map, table_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_map, table_w)

    # Broadcast the height bias over key columns and the width bias over
    # key rows, then flatten back to the sequence layout.
    attn_map = attn.view(B, q_h, q_w, k_h, k_w)
    attn_map = attn_map + bias_h[:, :, :, :, None] + bias_w[:, :, :, None, :]
    return attn_map.view(B, q_h * q_w, k_h * k_w)
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Return absolute positional embeddings laid out on an (H, W) grid,
    dropping the cls_token entry and bicubically resizing when the stored
    (square) grid differs from the requested size.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]

    num_positions = abs_pos.shape[1]
    side = int(math.sqrt(num_positions))
    # The stored embeddings must form a square grid.
    assert side * side == num_positions

    if side == h and side == w:
        return abs_pos.reshape(1, h, w, -1)

    resized = F.interpolate(
        abs_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding, implemented as a single strided convolution.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_kwargs)

    def forward(self, x):
        """Project ``x`` and return it channels-last: B C H W -> B H W C."""
        return self.proj(x).permute(0, 2, 3, 1)
class LayerNorm2D(nn.Module):
    """
    LayerNorm for channels-first feature maps: each spatial location of a
    (batch_size, channels, height, width) input is normalized to zero mean
    and unit variance over the channel dimension, then scaled and shifted
    by learnable per-channel parameters.
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        """
        Args:
            normalized_shape (int): Number of channels.
            eps (float): Value added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` over dim 1 and apply the per-channel affine transform."""
        mean = x.mean(1, keepdim=True)
        variance = (x - mean).pow(2).mean(1, keepdim=True)
        normalized = (x - mean) / torch.sqrt(variance + self.eps)
        # Reshape the (C,) parameters to (C, 1, 1) so they broadcast over H and W.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment