Commit 0911606e authored by wanglch's avatar wanglch
Browse files

Initial commit

parent bde8d813
# TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySD7w.png" width="300"/>
<p>
> [**TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document**](https://arxiv.org/abs/2403.04473)<br>
> Yuliang Liu, Biao Yang, Qiang Liu, Zhang Li, Zhiyin Ma, Shuo Zhang, Xiang Bai <br>
[![arXiv](https://img.shields.io/badge/Arxiv-2403.04473-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.04473)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](monkey_model/text_monkey/README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7684/)
[![Data](https://img.shields.io/badge/Data-yellow)](https://huggingface.co/datasets/MelosY/TextMonkey_Data)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](https://www.modelscope.cn/models/lvskiller/TextMonkey)
-----
**TextMonkey** is a multi-modal large model (LMM) focused on text-related tasks, including document question answering and scene text question answering. Compared with Monkey, TextMonkey has been improved in many aspects: by using zero-initialized Shifted Window Attention, TextMonkey realizes information interaction between windows at a higher input resolution; by calculating similarity to filter out important image features, not only can it simplify the input, but it can also improve the performance of the model. Furthermore, TextMonkey enhances interpretability and reduces hallucinations by extending multiple text-related tasks and incorporating location information into responses. At the same time, after fine-tuning, TextMonkey can also have the ability to understand user instructions and click on the corresponding location in the APP Agent, demonstrating its huge potential for downstream applications.
# TODO
- [x] Open source code, weight, and data
- [ ] Support training using 3090 GPUs (24Gb video memory)
- [ ] Improve Chinese language proficiency
- [ ] TextMonkey with different LLMs
# Model Zoo
TextMonkey was trained using 8 A800 GPUs on a dataset of 400k data, requiring approximately 1 day and 6 hours of training time. It is capable of running inference on a 3090 GPU.
| Method | LLM | STVQA | TextVQA | OCRVQA | DocVQA | InfoVQA | ChartQA | FUNSD | SROIE | POIE | OCRBench |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP2-OPT-6.7B | OPT-6.7B | 20.9 | 23.5 | 9.7 | 3.2 | 11.3 | 3.4 | 0.2 | 0.1 | 0.3 | 235 |
| mPLUG-Owl | LLaMA-7B | 30.5 | 34.0 | 21.1 | 7.4 | 20.0 | 7.9 | 0.5 | 1.7 | 2.5 | 297 |
| InstructBLIP | Vircuna-7B | 27.4 | 29.1 | 41.3 | 4.5 | 16.4 | 5.3 | 0.2 | 0.6 | 1.0 | 276 |
| LLaVAR | Vircuna-7B | 39.2 | 41.8 | 24.0 | 12.3 | 16.5 | 12.2 | 0.5 | 5.2 | 5.9 | 346 |
| BLIVA | Vircuna-7B | 32.1 | 33.3 | 50.7 | 5.8 | 23.6 | 8.7 | 0.2 | 0.7 | 2.1 | 291 |
| mPLUG-Owl2 | LLaMA-7B | 49.8 | 53.9 | 58.7 | 17.9 | 18.9 | 19.4 | 1.4 | 3.2 | 9.9 | 366 |
| LLaVA1.5-7B$ | Vircuna-7B | 38.1 | 38.7 | 58.1 | 8.5 | 14.7 | 9.3 | 0.2 | 1.7 | 2.5 | 297 |
| TGDoc$ | Vircuna-7B | 36.3 | 46.2 | 37.2 | 9.0 | 12.8 | 12.7 | 1.4 | 3.0 | 22.2 | - |
| UniDoc | Vircuna-7B | 35.2 | 46.2 | 36.8 | 7.7 | 14.7 | 10.9 | 1.0 | 2.9 | 5.1 | - |
| DocPedia | Vircuna-7B | 45.5 | 60.2 | 57.2 | 47.1 | 15.2 | 46.9 | 9.9 | 21.4 | 39.9 | - |
| Monkey | Qwen-7B | 54.7 | 64.3 | 64.4 | 50.1 | 25.8 | 54.0 | 24.1 | 41.9 | 19.9 | 514 |
| InternVL | - | 62.2 | 59.8 | 30.5 | 28.7 | 23.6 | 45.6 | 6.5 | 26.4 | 25.9 | 517 |
| InternLM-XComposer2 | InternLM-7B | 59.6 | 62.2 | 49.6 | 39.7 | 28.6 | 51.6 | 15.3 | 34.2 | 49.3 | 511 |
| TextMonkey (~400k data)| Qwen-7B | 61.8 | 65.9 | 71.3 | 64.3 | 28.2 | 58.2 | 32.3 | 47.0 | 27.9 | 561 |
| TextMonkey (~500k data) | Qwen-7B | 61.2 | 64.3 | 72.2 | 66.7 | 28.6 | 59.9 | 42.9 | 46.2 | 32.0 | 558 |
## Environment
```python
conda create -n textmonkey python=3.10
conda activate textmonkey
git clone https://github.com/Yuliang-Liu/Monkey.git
cd ./Monkey
pip install -r requirements.txt
```
## Evaluate
We also offer TextMonkey's model testing code, which you can explore above. You can execute the training code through executing:
```python
bash eval/eval_doc.sh
```
## Train
Execute the training code:
```python
bash finetune/finetune_textmonkey.sh
```
## Cases
TextMonkey can accurately locate and recognize text in both scene images and document images. In addition, the natural image in (a), the document in (b), the diagram in (c), and the table in (d) all demonstrate TextMonkey’s ability to identify, understand, and locate text information in a variety of scenarios.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySSXO.png" width="700"/>
<p>
<br>
TextMonkey has shown strong feasibility as an agent for smartphone applications. After fine-tuning using 15k user click data from the Rico dataset, TextMonkey was able to understand user intent and click the corresponding icon.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySOV6.png" width="700"/>
<p>
<br>
## Citing TextMonkey
If you wish to refer to the baseline results published here, please use the following BibTeX entries:
```BibTeX
@article{liu2024textmonkey,
title={TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document},
author={Liu, Yuliang and Yang, Biao and Liu, Qiang and Li, Zhang and Ma, Zhiyin and Zhang, Shuo and Bai, Xiang},
journal={arXiv preprint arXiv:2403.04473},
year={2024}
}
```
## Copyright
We welcome suggestions to help us improve the TextMonkey. For any query, please contact Dr. Yuliang Liu: ylliu@hust.edu.cn. If you find something interesting, please also feel free to share with us through email or open an issue.
import math
from typing import Callable, Tuple
import torch
def self_soft_matching(
metric: torch.Tensor,
r: int,):
t = metric.shape[1]
with torch.no_grad():
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = metric[..., :, :], metric[..., :, :]
scores = a @ b.transpose(-1, -2) # a_lxb_l
b,_,_ = scores.shape
scores_diag = torch.tril(torch.ones(t,t))*2
scores_diag = scores_diag.expand(b, -1, -1).to(metric.device)
scores = scores-scores_diag
node_max, node_idx = scores.max(dim=-1) # a中最相似的点
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] # a中相似度排序并得到idx,降序
unm_idx = edge_idx[..., t-r:, :] # Unmerged Tokens # 后面的就是不merge的
def merge(src: torch.Tensor) -> torch.Tensor:
n, t1, c = src.shape
unm = src.gather(dim=-2, index=unm_idx.expand(n, r, c))
unm_idx_new = unm_idx
all_idx = unm_idx_new
all_max,all_idx_idx = torch.sort(all_idx,dim=1)
return unm.gather(dim=-2, index=all_idx_idx.expand(n, r, c))
return merge
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
import torch.nn as nn
import torch
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
#Resample model
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
from monkey_model.text_monkey.merge import *
class FeedForward(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
self.scale = nn.Parameter(torch.ones(1))
with torch.no_grad():
nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
nn.init.zeros_(self.fc2.weight)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.scale*x
return x
class Block(nn.Module):
def __init__(self, input_size,output_size):
super().__init__()
self.fc_1 = nn.Linear(input_size, output_size)
self.norm = nn.LayerNorm(output_size)
def forward(self, x):
x = self.fc_1(x)
x = self.norm(x)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm_media = nn.LayerNorm(dim)
self.norm_latents = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
x = self.norm_media(x)
latents = self.norm_latents(latents)
h = self.heads
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
# attention
sim = einsum("... i d, ... j d -> ... i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
return self.to_out(out)
class PerceiverResampler(nn.Module):
def __init__(
self,
*,
in_dim=1024,
out_dim=4096,
depth=1,
dim_head=128,
heads=8,
visual_tokens_num=512,
ff_mult=4,
):
super().__init__()
self.downsample = nn.Linear(out_dim,in_dim,bias=False)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=in_dim, dim_head=dim_head, heads=heads),
FeedForward(in_features=in_dim, hidden_features=in_dim,out_features=out_dim),
]
)
)
def forward(self, x,r=0):
B,L,C = x.shape
merge = self_soft_matching(x, r) # Replace with your features and r value
latents = merge(x)
down_x = self.downsample(x)
down_latent = self.downsample(latents)
for attn, ff in self.layers:
down_latent = attn(down_x, down_latent)
latents = ff(down_latent) + latents
return latents
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List
import numpy as np
import sys
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from monkey_model.text_monkey.window import CrossWindowAttention,PatchMerging
from monkey_model.text_monkey.resampler import *
import random
import numpy as np
import matplotlib.pyplot as plt
def reconstruct_matrix(windows):
temp =[]
for col in windows:
temp.append(torch.cat((col),dim=3))
all_img = torch.cat(temp,dim=2)
return all_img
def sliding_window(matrix, window_size, stride):
b,c,height, width = matrix.shape
window_rows = math.ceil((height - window_size[0]) / stride) + 1
window_cols = math.ceil((width - window_size[1]) / stride) + 1
#windows = np.zeros((window_rows, window_cols, window_size[0], window_size[1]))
windows = []
for i in range(window_rows):
windows_col = []
for j in range(window_cols):
window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
windows_col.append(window)
windows.append(windows_col)
return windows
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=nn.LayerNorm
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
return out.permute(1, 0, 2)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Lora_Adapter(nn.Module):
def __init__(self,
d_model=None,
out_feat=None,
r=16,
dropout=0.05):
super().__init__()
self.d_model = d_model
self.out_feat = out_feat
self.r = r
self.lora_scale = nn.Parameter(torch.ones(1))
self.lora_a = nn.Linear(self.d_model, self.r,bias=False)
self.lora_b = nn.Linear(self.r, self.out_feat,bias=False)
self.lora_dropout = nn.Dropout(p=dropout)
with torch.no_grad():
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_b.weight)
def forward(self, x ):
#residual = x if residual is None else residual
x = self.lora_dropout(x)
down = self.lora_a(x)
up = self.lora_b(down)
up = up * self.lora_scale
output = up
return output
class VisualAttention(nn.Module):
"""self-attention layer class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, embed_dim, num_heads,
bias=True, kdim=None, vdim=None,lora_repeat_num=4):
super(VisualAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
# Per attention head and per partition values.
assert embed_dim % num_heads == 0
self.hidden_size_per_attention_head = embed_dim // num_heads
self.num_attention_heads_per_partition = num_heads
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.in_proj_lora = []
for _ in range(lora_repeat_num):
self.in_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=3 * embed_dim))
self.in_proj_lora = nn.ModuleList(self.in_proj_lora)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj_lora = []
for _ in range(lora_repeat_num):
self.out_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=embed_dim))
self.out_proj_lora = nn.ModuleList(self.out_proj_lora)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(self, query, key, value, attn_mask = None,lora_idx = None):
# query/key/value: [sq, b, h]
sq, b, _ = query.size()
assert query is key, 'Only Support Self-Attention Currently'
sk = sq
mixed_x_layer = self.in_proj(query)
if lora_idx == None:
pass
else:
lora_res = self.in_proj_lora[lora_idx](query)
mixed_x_layer += lora_res
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = mixed_x_layer.split(
self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(sq,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
else:
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
value_layer = value_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(b,
self.num_attention_heads_per_partition,
sq, self.hidden_size_per_attention_head)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
if lora_idx == None:
pass
else:
lora_res = self.out_proj_lora[lora_idx](context_layer)
output += lora_res
return output
class VisualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
is_cross_attention: bool = False,
lora_repeat_num = 4,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head,lora_repeat_num = lora_repeat_num)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.mlp_lora = []
for _ in range(lora_repeat_num):
self.mlp_lora.append(Lora_Adapter(d_model=d_model,out_feat=d_model,r=32))
self.mlp_lora = nn.ModuleList(self.mlp_lora)
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
lora_idx = None
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(q_x, k_x, v_x, attn_mask=attn_mask,lora_idx=lora_idx)
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
lora_idx = None
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask,lora_idx=lora_idx)
residual = x
x = x + self.mlp(self.ln_2(x))
if lora_idx == None:
pass
else:
x += self.mlp_lora[lora_idx](residual)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
lora_repeat_num=4,
add_window=False,
window_all=False,
image_size=(896,896)
):
super().__init__()
self.width = width
self.layers = layers
self.add_window = add_window
self.window_all = window_all
self.window_pos = [2,6,24,46]
self.window_dim = [128,256,512,1024]
self.window_head = [4,8,16,32]
if isinstance(image_size, tuple) or isinstance(image_size, list):
image_size = tuple(size // 14 for size in image_size)
else:
image_size = image_size//14
if self.add_window:
self.window_attention = []
for idx in range(len(self.window_pos)):
self.window_attention.append(CrossWindowAttention(image_size=image_size,dim=1664,hidden_dim=self.window_dim[idx],head=self.window_head[idx]))
self.window_attention = nn.ModuleList(self.window_attention)
self.resblocks = nn.ModuleList([
VisualAttentionBlock(
width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer,lora_repeat_num=lora_repeat_num)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def get_cast_device(self) -> torch.device:
return self.resblocks[0].mlp.c_fc.weight.device
def forward(self, x,attn_mask: Optional[torch.Tensor] = None,lora_idx=None,image_size=(64,64)):
if isinstance(x,List):
window_idx = 0
for r_idx,r in enumerate(self.resblocks):
if self.add_window:
if r_idx in self.window_pos:
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].permute(1, 0, 2) # LND -> NLD
x[i][j] = x[i][j].permute(0, 2, 1) # shape = [*, width, grid ** 2,]
x[i][j] = x[i][j].reshape(x[i][j].shape[0], x[i][j].shape[1], 32,32)
whole_image = reconstruct_matrix(x) #shape = [*,width,grid,grid]
whole_image = self.window_attention[window_idx](whole_image,image_size)
x = sliding_window(whole_image,(32,32),32)
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].reshape(x[i][j].shape[0], x[i][j].shape[1], -1)
x[i][j] = x[i][j].permute(0, 2, 1) # shape = [*, grid ** 2, width]
x[i][j] = x[i][j].permute(1, 0, 2) # NLD -> LND
window_idx += 1
if lora_idx is None or lora_idx == 0 :
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = r(x[i][j],attn_mask=attn_mask,lora_idx=lora_idx)
else:
temp_lora_idx = 0
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = r(x[i][j],attn_mask=attn_mask,lora_idx=temp_lora_idx)
temp_lora_idx += 1
return x
else:
for r in self.resblocks:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
def __init__(
self,
image_size,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
n_queries: int = 256,
output_dim: int = 512,
lora_repeat_num: int = 0,
add_window: bool = False,
use_global:bool =False,
resampler=False,
r=512,
**kwargs
):
super().__init__()
if isinstance(image_size, tuple) or isinstance(image_size, list):
image_height, image_width = self.image_size = image_size
else:
image_height, image_width = self.image_size = (image_size,image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
self.add_window = add_window
self.use_global = use_global
self.resampler = resampler
self.r = r
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
self.image_transform = transforms.Compose([
transforms.Resize(
self.image_size,
interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
self.lora_repeat_num = lora_repeat_num
# class embeddings and positional embeddings
scale = width ** -0.5
self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.ln_pre = norm_layer(width)
self.transformer = TransformerBlock(
width,
layers,
heads,
mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
lora_repeat_num=lora_repeat_num,
add_window=add_window,
image_size=image_size
)
self.attn_pool = Resampler(
grid_size=int(math.sqrt(256)),
embed_dim=output_dim,
num_heads=output_dim // 128,
kv_dim=width,
norm_layer=norm_layer,
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
if self.resampler:
self.downresampler = PerceiverResampler()
def forward(self, x: torch.Tensor,lora_idx=None,add_window=False):
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [b, width, grid, grid]
b,c,h,w = x.shape
if add_window:
x = sliding_window(x,(32,32),32)
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].reshape(x[i][j].shape[0], x[i][j].shape[1], -1)
x[i][j] = x[i][j].permute(0, 2, 1) # shape = [*, grid ** 2, width]
x[i][j] = x[i][j] + get_abs_pos(self.positional_embedding,x[i][j].size(1))
x[i][j] = self.ln_pre(x[i][j])
x[i][j] = x[i][j].permute(1, 0, 2) # NLD -> LND
else:
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, x.size(1))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x,lora_idx=lora_idx,image_size=(h,w))
if add_window:
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].permute(1, 0, 2) # LND -> NLD
x[i][j] = self.attn_pool(x[i][j])
x[i][j] = self.ln_post(x[i][j])
x[i][j] = x[i][j] @ self.proj
temp =[]
for col in x:
temp.append(torch.cat((col),dim=1))
x = torch.cat(temp,dim=1)
else:
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x
def encode(self, image_paths: List[str],lora_idx=None,input_image=None):
if input_image is None:
images = []
for image_path in image_paths:
if image_path.startswith("http://") or image_path.startswith("https://"):
image = Image.open(requests.get(image_path, stream=True).raw)
else:
image = Image.open(image_path)
image = image.convert("RGB")
## to imitate transmission loss in the real world.
if self.training:
output = BytesIO()
qual = random.randint(20, 100)
image.save(output, format='JPEG', quality=qual)
image_data = output.getvalue()
image =Image.open(BytesIO(image_data))
images.append(self.image_transform(image))
images = torch.stack(images, dim=0)
else:
images = input_image
images_448 = F.interpolate(images, size=(448,448), mode='bicubic')
if lora_idx == 1:
local_feat = self(images,0,add_window=True)
else:
local_feat = self(images,lora_idx,add_window=True)
if self.resampler:
local_feat = self.downresampler(local_feat,r = self.r)
if self.use_global:
global_feat = self(images_448,lora_idx=None,add_window=False)
return torch.cat([local_feat,global_feat],dim=1)
else:
return local_feat
import logging
import math
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import numpy as np
from itertools import repeat
import collections.abc
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from functools import partial
from itertools import repeat
import collections.abc
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < round_limit * v:
new_v += divisor
return new_v
def extend_tuple(x, n):
# pads a tuple to specified n by padding with last value
if not isinstance(x, (tuple, list)):
x = (x,)
else:
x = tuple(x)
pad_n = n - len(x)
if pad_n <= 0:
return x[:n]
return x + (x[-1],) * pad_n
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=False)
self.act = act_layer()
self.drop1 = nn.Dropout(0.05)
self.fc2 = linear_layer(hidden_features, out_features, bias=False)
self.scale = nn.Parameter(torch.ones(1))
with torch.no_grad():
nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
nn.init.zeros_(self.fc2.weight)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.scale*x
return x
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
def window_partition(
x: torch.Tensor,
window_size: Tuple[int, int],
) -> torch.Tensor:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x
def get_relative_position_index(win_h: int, win_w: int):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_h - 1 # shift to start from 0
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
return relative_coords.sum(-1) # Wh*Ww, Wh*Ww
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.reduction_2 = nn.Linear(2 * dim, dim, bias=False)
self.norm = norm_layer(4 * dim)
self.norm_2 = norm_layer(2 * dim)
def forward(self, x):
"""
X bxcxgxg
x: B, H*W, C
"""
size= self.input_resolution
B, C, G,_ = x.shape
assert G*G == size * size, "input feature has wrong size"
x = x.reshape(x.shape[0],x.shape[1],-1) #bxcxl
x = x.permute(0,2,1) #bxlxc
B, _, C = x.shape
x = x.view(B, size, size, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
x = self.norm_2(x)
x = self.reduction_2(x)
x = x.view(B,-1,C)
x = x.permute(0,2,1) #bxcxl
x = x.reshape(x.shape[0],x.shape[1],G//2,G//2) #bxcxl
return x
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports shifted and non-shifted windows.
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
num_heads: int,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
qkv_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
):
"""
Args:
dim: Number of input channels.
num_heads: Number of attention heads.
head_dim: Number of channels per head (dim // num_heads if not set)
window_size: The height and width of the window.
qkv_bias: If True, add a learnable bias to query, key, value.
attn_drop: Dropout ratio of attention weight.
proj_drop: Dropout ratio of output.
"""
super().__init__()
self.window_size = to_2tuple(window_size) # Wh, Ww
win_h, win_w = self.window_size
self.window_area = win_h * win_w
self.num_heads = num_heads
head_dim = head_dim or dim // num_heads
attn_dim = head_dim * num_heads
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
# get pair-wise relative position index for each token inside the window
self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias.unsqueeze(0)
def forward(self, x, mask: Optional[torch.Tensor] = None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn + self._get_rel_pos_bias()
if mask is not None:
num_win = mask.shape[0]
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B_, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
"""
def __init__(
self,
dim: int,
hidden_dim:int,
input_resolution: _int_or_tuple_2_t,
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
shift_size: int = 0,
mlp_ratio: float = 1.,
qkv_bias: bool = True,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
window_size: Window size.
num_heads: Number of attention heads.
head_dim: Enforce the number of channels per head
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
proj_drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.input_resolution = input_resolution
ws, ss = self._calc_window_shift(window_size, shift_size)
self.window_size: Tuple[int, int] = ws
self.shift_size: Tuple[int, int] = ss
self.window_area = self.window_size[0] * self.window_size[1]
self.mlp_ratio = mlp_ratio
self.downsample = nn.Linear(dim,hidden_dim,bias=False)
self.norm1 = norm_layer(hidden_dim)
self.attn = WindowAttention(
hidden_dim,
num_heads=num_heads,
head_dim=head_dim,
window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(hidden_dim)
self.mlp = Mlp(
in_features=hidden_dim,
hidden_features=int(hidden_dim * 2),
out_features = 1664,
act_layer=act_layer,
drop=proj_drop,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
attn_mask = self.calc_attn(self.input_resolution)
self.register_buffer("attn_mask", attn_mask, persistent=False)
def calc_attn(self,input_resolution):
H, W = input_resolution
H = math.ceil(H / self.window_size[0]) * self.window_size[0]
W = math.ceil(W / self.window_size[1]) * self.window_size[1]
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = to_2tuple(target_shift_size)
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)
def _attn(self, x):
B, H, W, C = x.shape
# cyclic shift
has_shift = any(self.shift_size)
if has_shift:
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
else:
shifted_x = x
# pad for resolution not divisible by window size
pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C
attn_mask = self.attn_mask
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
shifted_x = shifted_x[:, :H, :W, :].contiguous()
# reverse cyclic shift
if has_shift:
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
return x
def forward(self, x):
x = self.downsample(x)
B, H, W, C = x.shape
# C = hidden_dim
x = x + self.drop_path1(self._attn(self.norm1(x)))
x = x.reshape(B, -1, C)
x = self.drop_path2(self.mlp(self.norm2(x)))
x = x.reshape(B, H, W, self.dim)
return x
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(1)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
class CrossWindowAttention(nn.Module):
""" Patch Merging Layer.
"""
def __init__(
self,
image_size,
dim,
hidden_dim,
head,
window_size = 12
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels (or 2 * dim if None)
norm_layer: Normalization layer.
"""
super().__init__()
if isinstance(image_size, tuple) or isinstance(image_size, list):
self.image_size = image_size
else:
self.image_size = (image_size,image_size)
self.dim = dim
self.window_size = window_size
self.shift_size = window_size // 2
self.position_embedding = nn.Parameter(torch.zeros(1, self.image_size[0]*self.image_size[1], 1664))
trunc_normal_(self.position_embedding, std=.02)
self.shift_attn = SwinTransformerBlock(dim=dim,hidden_dim=hidden_dim,input_resolution = self.image_size,num_heads=head,window_size =self.window_size,shift_size=self.shift_size)
def forward(self, x,image_size):
# X bxcxgxg
B,C,G,_=x.shape
x = x.reshape(x.shape[0],x.shape[1],-1) #bxcxl
x = x.permute(0,2,1) #bxlxc
B, L, C = x.shape
residual = x
H,W = image_size
pos_embed = get_abs_pos(self.position_embedding,x.size(1))
x = x + pos_embed
x = x.view(B,H,W,C)
x = self.shift_attn(x)
x = x.view(B,-1,C)
x = x + residual
x = x.permute(0,2,1) #bxcxl
x = x.reshape(x.shape[0],x.shape[1],G,G) #bxcxl
return x
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenization classes for QWen."""
import base64
import logging
import os
import requests
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
import tiktoken
import numpy as np
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
from transformers import PreTrainedTokenizer, AddedToken
from transformers.utils import try_to_load_from_cache
import matplotlib.colors as mcolors
from matplotlib.font_manager import FontProperties
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
ENDOFTEXT,
IMSTART,
IMEND,
) + EXTRAS
# IMG_TOKEN_SPAN = 1024
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
def _list_find(
input_list: List[Any],
candidates: Tuple[Any],
start: int = 0,
):
for i in range(start, len(input_list)):
if input_list[i] in candidates:
return i
return -1
def _replace_closed_tag(
input_tokens: List[Any],
start_tags: Union[Any, Tuple[Any]],
end_tags: Union[Any, Tuple[Any]],
inclusive_replace_func: Callable,
exclusive_replace_func: Callable = lambda x: x,
):
if isinstance(start_tags, (str, int)):
start_tags = (start_tags,)
if isinstance(end_tags, (str, int)):
end_tags = (end_tags,)
assert len(start_tags) == len(end_tags)
output_tokens = []
end = 0
while True:
start = _list_find(input_tokens, start_tags, end)
if start == -1:
break
output_tokens.extend(exclusive_replace_func(input_tokens[end : start]))
tag_idx = start_tags.index(input_tokens[start])
end = _list_find(input_tokens, (end_tags[tag_idx],), start)
if end == -1:
raise ValueError("Unclosed image token")
output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1]))
end += 1
output_tokens.extend(exclusive_replace_func(input_tokens[end : ]))
return output_tokens
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
errors="replace",
image_start_tag='<img>',
image_end_tag='</img>',
image_pad_tag='<imgpad>',
ref_start_tag='<ref>',
ref_end_tag='</ref>',
box_start_tag='<box>',
box_end_tag='</box>',
quad_start_tag='<quad>',
quad_end_tag='</quad>',
**kwargs,
):
super().__init__(**kwargs)
self.image_start_tag = image_start_tag
self.image_end_tag = image_end_tag
self.image_pad_tag = image_pad_tag
self.ref_start_tag = ref_start_tag
self.ref_end_tag = ref_end_tag
self.box_start_tag = box_start_tag
self.box_end_tag = box_end_tag
self.quad_start_tag = quad_start_tag
self.quad_end_tag = quad_end_tag
self.IMAGE_ST = (
ref_start_tag, ref_end_tag,
box_start_tag, box_end_tag,
quad_start_tag, quad_end_tag,
image_start_tag, image_end_tag,
image_pad_tag
)
self.IMG_TOKEN_SPAN = 1280
self.errors = errors # how to handle errors in decoding
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
self.special_tokens = {
token: index
for index, token in enumerate(
SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
)
}
self.img_start_id = self.special_tokens[self.image_start_tag]
self.img_end_id = self.special_tokens[self.image_end_tag]
self.img_pad_id = self.special_tokens[self.image_pad_tag]
self.ref_start_id = self.special_tokens[self.ref_start_tag]
self.ref_end_id = self.special_tokens[self.ref_end_tag]
self.box_start_id = self.special_tokens[self.box_start_tag]
self.box_end_id = self.special_tokens[self.box_end_tag]
self.quad_start_id = self.special_tokens[self.quad_start_tag]
self.quad_end_id = self.special_tokens[self.quad_end_tag]
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
self.decoder = {
v: k for k, v in self.mergeable_ranks.items()
} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})
self.tokenizer = enc # type: tiktoken.Encoding
self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]
def __getstate__(self):
# for pickle lovers
state = self.__dict__.copy()
del state['tokenizer']
return state
def __setstate__(self, state):
# tokenizer is not python native; don't pass it; rebuild it
self.__dict__.update(state)
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
self.tokenizer = enc
def __len__(self) -> int:
return self.tokenizer.n_vocab
def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks
def convert_tokens_to_ids(
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
if not special_tokens and new_tokens:
raise ValueError('Adding regular tokens is not supported')
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST:
raise ValueError('Adding unknown special tokens is not supported')
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
Save only the vocabulary of the tokenizer (vocabulary).
Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)
def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
Converts a string in a sequence of tokens.
Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.
Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)
# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(
text, allowed_special=allowed_special, disallowed_special=disallowed_special
):
tokens.append(self.decoder[t])
def _encode_imgurl(img_tokens):
assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
img_tokens = img_tokens[1:-1]
img_url = b''.join(img_tokens)
out_img_tokens = list(map(self.decoder.get, img_url))
if len(out_img_tokens) > self.IMG_TOKEN_SPAN:
raise ValueError("The content in {}..{} is too long".format(
self.image_start_tag, self.image_end_tag))
out_img_tokens.extend([self.image_pad_tag] * (self.IMG_TOKEN_SPAN - len(out_img_tokens)))
out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
return out_img_tokens
return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Converts a sequence of tokens in a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type types or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text
@property
def vocab_size(self):
return self.tokenizer.n_vocab
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")
def _tokenize(self, text: str, **kwargs):
"""
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens.
"""
raise NotImplementedError
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
def _decode_imgurl(img_token_ids):
assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
img_token_ids = img_token_ids[1:-1]
img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)]
img_url = bytes(img_token_ids).decode('utf-8')
return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]
token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
def to_list_format(self, text: str):
text = unicodedata.normalize("NFC", text)
token_ids = self.tokenizer.encode(
text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))
def _encode_vl_info(tokens):
if len(tokens) == 0:
return []
if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
key = 'image'
elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
key = 'ref'
elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
key = 'box'
elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
key = 'quad'
else:
_tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
_tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8')
return [{key: val}]
return _replace_closed_tag(
token_ids,
(self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
(self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
_encode_vl_info,
_encode_vl_info,
)
def from_list_format(self, list_format: List[Dict]):
text = ''
num_images = 0
for ele in list_format:
if 'image' in ele:
num_images += 1
text += f'Picture {num_images}:'
text += self.image_start_tag + ele['image'] + self.image_end_tag
text += '\n'
elif 'text' in ele:
text += ele['text']
elif 'box' in ele:
if 'ref' in ele:
text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
for box in ele['box']:
text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
else:
raise ValueError("Unsupport element: " + str(ele))
return text
def _fetch_latest_picture(self, response, history):
if history is None:
history = []
_history = history + [(response, None)]
for q, r in _history[::-1]:
for ele in self.to_list_format(q)[::-1]:
if 'image' in ele:
return ele['image']
return None
def _fetch_all_box_with_ref(self, text):
list_format = self.to_list_format(text)
output = []
for i, ele in enumerate(list_format):
if 'box' in ele:
bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
assert len(bbox) == 4
output.append({'box': bbox})
if i > 0 and 'ref' in list_format[i-1]:
output[-1]['ref'] = list_format[i-1]['ref'].strip()
return output
def draw_bbox_on_latest_picture(
self,
response,
history=None,
) -> Optional[Image.Image]:
image = self._fetch_latest_picture(response, history)
if image is None:
return None
if image.startswith("http://") or image.startswith("https://"):
image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
h, w = image.height, image.width
else:
image = np.asarray(Image.open(image).convert("RGB"))
h, w = image.shape[0], image.shape[1]
visualizer = Visualizer(image)
boxes = self._fetch_all_box_with_ref(response)
if not boxes:
return None
color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
for box in boxes:
if 'ref' in box: # random new color for new refexps
color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
x1, y1, x2, y2 = box['box']
x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
if 'ref' in box:
visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
return visualizer.output
import colorsys
import logging
import math
import numpy as np
import matplotlib as mpl
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
import random
logger = logging.getLogger(__name__)
class VisImage:
def __init__(self, img, scale=1.0):
self.img = img
self.scale = scale
self.width, self.height = img.shape[1], img.shape[0]
self._setup_figure(img)
def _setup_figure(self, img):
fig = mplfigure.Figure(frameon=False)
self.dpi = fig.get_dpi()
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363)
fig.set_size_inches(
(self.width * self.scale + 1e-2) / self.dpi,
(self.height * self.scale + 1e-2) / self.dpi,
)
self.canvas = FigureCanvasAgg(fig)
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
self.fig = fig
self.ax = ax
self.reset_image(img)
def reset_image(self, img):
img = img.astype("uint8")
self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
def save(self, filepath):
self.fig.savefig(filepath)
def get_image(self):
canvas = self.canvas
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
return rgb.astype("uint8")
class Visualizer:
def __init__(self, img_rgb, metadata=None, scale=1.0):
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
self.font_path = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf")
self.output = VisImage(self.img, scale=scale)
self.cpu_device = torch.device("cpu")
# too small texts are useless, therefore clamp to 14
self._default_font_size = max(
np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
)
def draw_text(
self,
text,
position,
*,
font_size=None,
color="g",
horizontal_alignment="center",
rotation=0,
):
if not font_size:
font_size = self._default_font_size
# since the text background is dark, we don't want the text to be dark
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
x, y = position
self.output.ax.text(
x,
y,
text,
size=font_size * self.output.scale,
fontproperties=FontProperties(fname=self.font_path),
bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
verticalalignment="top",
horizontalalignment=horizontal_alignment,
color=color,
zorder=10,
rotation=rotation,
)
return self.output
def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
x0, y0, x1, y1 = box_coord
width = x1 - x0
height = y1 - y0
linewidth = max(self._default_font_size / 4, 1)
self.output.ax.add_patch(
mpl.patches.Rectangle(
(x0, y0),
width,
height,
fill=False,
edgecolor=edge_color,
linewidth=linewidth * self.output.scale,
alpha=alpha,
linestyle=line_style,
)
)
return self.output
def get_output(self):
return self.output
{
"auto_map": {
"AutoTokenizer": [
"tokenization_qwen.QWenTokenizer",
null
]
},
"clean_up_tokenization_spaces": true,
"model_max_length": 2048,
"padding_side": "right",
"tokenizer_class": "QWenTokenizer"
}
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
def reconstruct_matrix(windows):
temp =[]
for col in windows:
temp.append(torch.cat((col),dim=3))
all_img = torch.cat(temp,dim=2)
return all_img
def sliding_window(matrix, window_size, stride):
b,c,height, width = matrix.shape
window_rows = (height - window_size[0]) // stride + 1
window_cols = (width - window_size[1]) // stride + 1
windows = []
for i in range(window_rows):
windows_col = []
for j in range(window_cols):
window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
windows_col.append(window)
windows.append(windows_col)
return windows
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=nn.LayerNorm
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
return out.permute(1, 0, 2)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Lora_Adapter(nn.Module):
def __init__(self,
d_model=None,
out_feat=None,
r=16,
dropout=0.05):
super().__init__()
self.d_model = d_model
self.out_feat = out_feat
self.r = r
self.lora_scale = nn.Parameter(torch.ones(1))
self.lora_a = nn.Linear(self.d_model, self.r,bias=False)
self.lora_b = nn.Linear(self.r, self.out_feat,bias=False)
self.lora_dropout = nn.Dropout(p=dropout)
with torch.no_grad():
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_b.weight)
def forward(self, x ):
#residual = x if residual is None else residual
x = self.lora_dropout(x)
down = self.lora_a(x)
up = self.lora_b(down)
up = up * self.lora_scale
output = up
return output
class VisualAttention(nn.Module):
"""self-attention layer class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, embed_dim, num_heads,
bias=True, kdim=None, vdim=None,lora_repeat_num=4):
super(VisualAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
# Per attention head and per partition values.
assert embed_dim % num_heads == 0
self.hidden_size_per_attention_head = embed_dim // num_heads
self.num_attention_heads_per_partition = num_heads
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.in_proj_lora = []
for _ in range(lora_repeat_num):
self.in_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=3 * embed_dim))
self.in_proj_lora = nn.ModuleList(self.in_proj_lora)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj_lora = []
for _ in range(lora_repeat_num):
self.out_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=embed_dim))
self.out_proj_lora = nn.ModuleList(self.out_proj_lora)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(self, query, key, value, attn_mask = None,idx = None):
# query/key/value: [sq, b, h]
sq, b, _ = query.size()
assert query is key, 'Only Support Self-Attention Currently'
sk = sq
mixed_x_layer = self.in_proj(query)
if idx == None:
pass
else:
lora_res = self.in_proj_lora[idx](query)
mixed_x_layer += lora_res
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = mixed_x_layer.split(
self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(sq,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
else:
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
value_layer = value_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(b,
self.num_attention_heads_per_partition,
sq, self.hidden_size_per_attention_head)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
if idx == None:
pass
else:
lora_res = self.out_proj_lora[idx](context_layer)
output += lora_res
return output
class VisualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
is_cross_attention: bool = False,
lora_repeat_num = 4,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head,lora_repeat_num = lora_repeat_num)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.mlp_lora = []
for _ in range(lora_repeat_num):
self.mlp_lora.append(Lora_Adapter(d_model=d_model,out_feat=d_model,r=32))
self.mlp_lora = nn.ModuleList(self.mlp_lora)
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
idx = None
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(q_x, k_x, v_x, attn_mask=attn_mask,idx=idx)
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
idx = None
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask,idx=idx)
residual = x
x = x + self.mlp(self.ln_2(x))
if idx == None:
pass
else:
x += self.mlp_lora[idx](residual)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
lora_repeat_num=4
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList([
VisualAttentionBlock(
width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer,lora_repeat_num=lora_repeat_num)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def get_cast_device(self) -> torch.device:
return self.resblocks[0].mlp.c_fc.weight.device
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None,idx=None):
for r in self.resblocks:
x = r(x, attn_mask=attn_mask,idx=idx)
return x
class VisionTransformer(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
n_queries: int = 256,
output_dim: int = 512,
lora_repeat_num: int = 4,
**kwargs
):
super().__init__()
if isinstance(image_size, tuple) or isinstance(image_size, list):
image_height, image_width = self.image_size = image_size
else:
image_height, image_width = self.image_size = (image_size,image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
self.image_transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
# class embeddings and positional embeddings
scale = width ** -0.5
self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.ln_pre = norm_layer(width)
self.transformer = TransformerBlock(
width,
layers,
heads,
mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
lora_repeat_num=lora_repeat_num
)
self.attn_pool = Resampler(
grid_size=int(math.sqrt(n_queries)),
embed_dim=output_dim,
num_heads=output_dim // 128,
kv_dim=width,
norm_layer=norm_layer,
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
def forward(self, x: torch.Tensor,idx=None):
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, x.size(1))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x,idx=idx)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x
def encode(self, image_paths: List[str]):
images = []
for image_path in image_paths:
if image_path.startswith("http://") or image_path.startswith("https://"):
image = Image.open(requests.get(image_path, stream=True).raw)
else:
image = Image.open(image_path)
image = image.convert("RGB")
images.append(self.image_transform(image))
images = torch.stack(images, dim=0)
B,C,H,W = images.shape
windows = sliding_window(images,window_size=(448,448),stride=448)
images_448 = F.interpolate(images, size=(448,448), mode='bicubic')
return windows,images_448
if __name__ == "__main__":
pass
visual = VisionTransformer(
image_size= 896,
patch_size= 14,
width=1664,
layers = 48,
heads= 16,
mlp_ratio = 4.9231,
output_dim= 4096)
img = torch.randn(1,3,896,896)
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
# Define LoRA Config
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["in_proj","out_proj","c_fc","c_proj"],
lora_dropout=0.05,
bias="none",
)
# prepare int-8 model for training
model = visual
# add LoRA adaptor
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print(model)
print(visual)
CUDA_VISIBLE_DEVICES=3,4 python demo_textmonkey.py -c /home/wanglch/projects/TextMonkey/TextMonkey_base
\ No newline at end of file
CUDA_VISIBLE_DEVICES=3,4 python demo_textmonkey.py -c ../TextMonkey/TextMonkey_base
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment