Unverified commit b63ace06, authored by Melos and committed by GitHub

TextMonkey (#75)



* textmonkey

* textmonkey code

* Delete README_cn.md

---------
Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
parent ab58e6f0
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="echo840/Monkey-Chat")  # echo840/Monkey-Chat or echo840/Monkey
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--question", type=str, default=None)
    args = parser.parse_args()

    checkpoint = args.model_path
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id

    img_path = args.image_path
    question = args.question
    if question == "Generate the detailed caption in English:" and "Monkey-Chat" not in checkpoint:
        query = f'<img>{img_path}</img> Generate the detailed caption in English: '  # detailed caption
    else:
        query = f'<img>{img_path}</img> {question} Answer: '  # VQA

    inputs = tokenizer(query, return_tensors='pt', padding='longest')
    attention_mask = inputs.attention_mask
    input_ids = inputs.input_ids

    pred = model.generate(
        input_ids=input_ids.cuda(),
        attention_mask=attention_mask.cuda(),
        do_sample=False,
        num_beams=1,
        max_new_tokens=512,
        min_new_tokens=1,
        length_penalty=1,
        num_return_sequences=1,
        output_hidden_states=True,
        use_cache=True,
        pad_token_id=tokenizer.eod_id,
        eos_token_id=tokenizer.eod_id,
    )
    response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
    print(f"Question: {question} Answer: {response}")
@@ -45,7 +45,7 @@ from .qwen_generation_utils import (
StopWordsLogitsProcessor,
)
from .visual import VisionTransformer
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
logger = logging.get_logger(__name__)
@@ -71,6 +71,142 @@ apply_rotary_emb_func = None
rms_norm = None
# Use flash attention; if your machine does not support it, set use_flash_attention to False.
use_flash_attention = True
def _import_flash_attn():
global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
try:
from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
apply_rotary_emb_func = __apply_rotary_emb_func
except ImportError:
logger.warn(
"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
)
# try:
# from flash_attn.ops.rms_norm import rms_norm as __rms_norm
# rms_norm = __rms_norm
# except ImportError:
# logger.warn(
# "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
# "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
# )
try:
import flash_attn
if not hasattr(flash_attn, '__version__'):
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
else:
if int(flash_attn.__version__.split(".")[0]) >= 2:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
else:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
flash_attn_unpadded_func = __flash_attn_unpadded_func
except ImportError:
logger.warn(
"Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention"
)
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
causal=False,
softmax_scale=None,
attention_dropout=0.0,
):
super().__init__()
assert flash_attn_unpadded_func is not None, (
"Please install FlashAttention first, " "e.g., with pip install flash-attn"
)
assert (
rearrange is not None
), "Please install einops first, e.g., with pip install einops"
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def unpad_input(self, hidden_states, attention_mask):
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
hidden_states = hidden_states[indices]
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
def pad_input(self, hidden_states, indices, batch, seqlen):
output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
dtype=hidden_states.dtype)
output[indices] = hidden_states
return rearrange(output, '(b s) ... -> b s ...', b=batch)
def forward(self, q, k, v, attention_mask=None):
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
assert all((i.is_cuda for i in (q, k, v)))
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = k.shape[1]
seqlen_out = seqlen_q
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q.device,
)
if batch_size > 1 and attention_mask is not None:
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
if q.size(0) == v.size(0):
q = q[indices_k]
cu_seqlens_q = cu_seqlens_k
seqlen_q = seqlen_k
v = v[indices_k]
else:
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=q.device,
)
if self.training:
assert seqlen_k == seqlen_q
is_causal = self.causal
dropout_p = self.dropout_p
else:
is_causal = seqlen_q == seqlen_k
dropout_p = 0
output = flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seqlen_q,
seqlen_k,
dropout_p,
softmax_scale=self.softmax_scale,
causal=is_causal,
)
if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
else:
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
output = output.view(new_shape)
return output
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -144,6 +280,9 @@ class QWenAttention(nn.Module):
self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
if use_flash_attention:
_import_flash_attn()
self.core_attention_flash = FlashSelfAttention(causal=True, attention_dropout=0)
def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
@@ -297,12 +436,18 @@ class QWenAttention(nn.Module):
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
query = query * logn_tensor.expand_as(query)
if self.training and SUPPORT_TORCH2 and use_flash_attention:
    attn_output = self.core_attention_flash(query, key, value)
    attn_weight = None
else:
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)
    attn_output, attn_weight = self._attn(
        query, key, value, registered_causal_mask, attention_mask, head_mask
    )
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
@@ -410,6 +555,10 @@ class QWenPreTrainedModel(PreTrainedModel):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
'''
There is no need to re-initialize the weights here.
'''
return
"""Initialize the weights.""" """Initialize the weights."""
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
...
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
from monkey_model.modeling_qwen import QWenModel,QWenPreTrainedModel,QWenLMHeadModel
from monkey_model.text_monkey.visual_text import VisionTransformer
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
logger = logging.get_logger(__name__)
class TextMonkeyModel(QWenModel):
def __init__(self, config):
super().__init__(config)
self.visual = VisionTransformer(**config.visual)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
if self.visual.lora_repeat_num>0:
images = self.visual.encode(images,lora_idx=self.visual.lora_repeat_num)
else:
images = self.visual.encode(images)
assert images.shape[0] == len(images)
else:
images = None
return super().forward(input_ids,
past_key_values,
attention_mask,
token_type_ids,
position_ids,
head_mask,inputs_embeds,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
images)
class TextMonkeyLMHeadModel(QWenLMHeadModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = TextMonkeyModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
# TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySD7w.png" width="300"/>
<p>
> [**TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document**](https://arxiv.org/abs/2403.04473)<br>
> Yuliang Liu, Biao Yang, Qiang Liu, Zhang Li, Zhiyin Ma, Shuo Zhang, Xiang Bai <br>
[![arXiv](https://img.shields.io/badge/Arxiv-2403.04473-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.04473)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](monkey_model/text_monkey/README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7684/)
[![Data](https://img.shields.io/badge/Data-yellow)](https://www.modelscope.cn/datasets/lvskiller/TextMonkey_data)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](https://www.modelscope.cn/models/lvskiller/TextMonkey)
-----
**TextMonkey** is a large multimodal model (LMM) focused on text-centric tasks, including document question answering and scene-text question answering. Compared with Monkey, TextMonkey improves in several ways: zero-initialized Shifted Window Attention enables information to flow between windows at higher input resolutions, and similarity-based filtering of redundant image tokens both simplifies the input and improves performance. TextMonkey further enhances interpretability and reduces hallucinations by extending to multiple text-related tasks and incorporating location information into its responses. After fine-tuning, TextMonkey can also understand user instructions and click the corresponding location when acting as an app agent, demonstrating its strong potential for downstream applications.
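As a rough, self-contained sketch of the similarity-based filtering idea (not the exact TextMonkey token resampler, which lives in `monkey_model/text_monkey/merge.py`; `keep_distinct_tokens` is a hypothetical helper), one can keep the `r` least redundant visual tokens as follows:

```python
import torch

def keep_distinct_tokens(feats: torch.Tensor, r: int) -> torch.Tensor:
    """Keep the r least redundant tokens of feats with shape (batch, num_tokens, dim)."""
    normed = feats / feats.norm(dim=-1, keepdim=True)
    sim = normed @ normed.transpose(-1, -2)                          # pairwise cosine similarity
    sim = sim - 2.0 * torch.eye(sim.shape[-1], device=sim.device)    # ignore self-similarity
    redundancy = sim.max(dim=-1).values                              # similarity to the closest other token
    keep_idx, _ = redundancy.argsort(dim=-1)[..., :r].sort(dim=-1)   # keep the least redundant r, in original order
    return feats.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, feats.shape[-1]))
```

The real `self_soft_matching` works the same way at heart: it scores each token by its maximum similarity to the other tokens and keeps the `r` tokens with the lowest scores.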
# TODO
- [x] Open-source code, weights, and data
- [ ] Improve Chinese language proficiency
- [ ] TextMonkey with different LLMs
# Model Zoo
| Method | LLM | STVQA | TextVQA | OCRVQA | DocVQA | InfoVQA | ChartQA | FUNSD | SROIE | POIE | OCRBench |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP2-OPT-6.7B | OPT-6.7B | 20.9 | 23.5 | 9.7 | 3.2 | 11.3 | 3.4 | 0.2 | 0.1 | 0.3 | 235 |
| mPLUG-Owl | LLaMA-7B | 30.5 | 34.0 | 21.1 | 7.4 | 20.0 | 7.9 | 0.5 | 1.7 | 2.5 | 297 |
| InstructBLIP | Vicuna-7B | 27.4 | 29.1 | 41.3 | 4.5 | 16.4 | 5.3 | 0.2 | 0.6 | 1.0 | 276 |
| LLaVAR | Vicuna-7B | 39.2 | 41.8 | 24.0 | 12.3 | 16.5 | 12.2 | 0.5 | 5.2 | 5.9 | 346 |
| BLIVA | Vicuna-7B | 32.1 | 33.3 | 50.7 | 5.8 | 23.6 | 8.7 | 0.2 | 0.7 | 2.1 | 291 |
| mPLUG-Owl2 | LLaMA-7B | 49.8 | 53.9 | 58.7 | 17.9 | 18.9 | 19.4 | 1.4 | 3.2 | 9.9 | 366 |
| LLaVA1.5-7B$ | Vicuna-7B | 38.1 | 38.7 | 58.1 | 8.5 | 14.7 | 9.3 | 0.2 | 1.7 | 2.5 | 297 |
| TGDoc$ | Vicuna-7B | 36.3 | 46.2 | 37.2 | 9.0 | 12.8 | 12.7 | 1.4 | 3.0 | 22.2 | - |
| UniDoc | Vicuna-7B | 35.2 | 46.2 | 36.8 | 7.7 | 14.7 | 10.9 | 1.0 | 2.9 | 5.1 | - |
| DocPedia | Vicuna-7B | 45.5 | 60.2 | 57.2 | 47.1 | 15.2 | 46.9 | 9.9 | 21.4 | 39.9 | - |
| Monkey | Qwen-7B | 54.7 | 64.3 | 64.4 | 50.1 | 25.8 | 54.0 | 24.1 | 41.9 | 19.9 | 514 |
| InternVL | - | 62.2 | 59.8 | 30.5 | 28.7 | 23.6 | 45.6 | 6.5 | 26.4 | 25.9 | 517 |
| InternLM-XComposer2 | InternLM-7B | 59.6 | 62.2 | 49.6 | 39.7 | 28.6 | 51.6 | 15.3 | 34.2 | 49.3 | 511 |
| TextMonkey (40k data)| Qwen-7B | 61.8 | 65.9 | 71.3 | 64.3 | 28.2 | 58.2 | 32.3 | 47.0 | 27.9 | 561 |
| TextMonkey (50k data) | Qwen-7B | 61.2 | 64.3 | 72.2 | 66.7 | 28.6 | 59.9 | 42.9 | 46.2 | 32.0 | 558 |
## Environment
```bash
conda create -n textmonkey python=3.10
conda activate textmonkey
git clone https://github.com/MelosY/TextMonkey.git
cd ./TextMonkey
pip install -r requirements.txt
```
## Evaluate
We also provide TextMonkey's evaluation code, which you can explore above. Run the evaluation with:
```bash
bash eval/eval_doc.sh
```
## Train
Execute the training code:
```bash
bash finetune/finetune_textmonkey.sh
```
## Cases
TextMonkey can accurately locate and recognize text in both scene images and document images. In addition, the natural image in (a), the document in (b), the diagram in (c), and the table in (d) all demonstrate TextMonkey’s ability to identify, understand, and locate text information in a variety of scenarios.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySSXO.png" width="700"/>
<p>
<br>
TextMonkey has shown strong feasibility as an agent for smartphone applications. After fine-tuning using 15k user click data from the Rico dataset, TextMonkey was able to understand user intent and click the corresponding icon.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySOV6.png" width="700"/>
<p>
<br>
## Citing TextMonkey
If you wish to refer to the baseline results published here, please use the following BibTeX entry:
```BibTeX
@article{liu2024textmonkey,
title={TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document},
author={Liu, Yuliang and Yang, Biao and Liu, Qiang and Li, Zhang and Ma, Zhiyin and Zhang, Shuo and Bai, Xiang},
journal={arXiv preprint arXiv:2403.04473},
year={2024}
}
```
## Copyright
We welcome suggestions to help us improve TextMonkey. For any queries, please contact Dr. Yuliang Liu: ylliu@hust.edu.cn. If you find something interesting, feel free to share it with us via email or by opening an issue.
import math
from typing import Callable, Tuple
import torch
def self_soft_matching(
    metric: torch.Tensor,
    r: int,
):
    t = metric.shape[1]
    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., :, :], metric[..., :, :]
        scores = a @ b.transpose(-1, -2)  # pairwise similarity, a_l x b_l
        b, _, _ = scores.shape
        scores_diag = torch.tril(torch.ones(t, t)) * 2
        scores_diag = scores_diag.expand(b, -1, -1).to(metric.device)
        scores = scores - scores_diag
        node_max, node_idx = scores.max(dim=-1)  # for each token, its most similar token in a
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]  # indices sorted by similarity, descending
        unm_idx = edge_idx[..., t - r:, :]  # unmerged tokens: the tail (least similar) r tokens are kept

    def merge(src: torch.Tensor) -> torch.Tensor:
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, r, c))
        unm_idx_new = unm_idx
        all_idx = unm_idx_new
        all_max, all_idx_idx = torch.sort(all_idx, dim=1)
        return unm.gather(dim=-2, index=all_idx_idx.expand(n, r, c))

    return merge
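# Usage sketch (dimensions are illustrative): self_soft_matching returns a closure
# that keeps the r least redundant tokens, as used by PerceiverResampler below.
#   x = torch.randn(2, 1024, 4096)
#   merge = self_soft_matching(x, r=512)
#   kept = merge(x)   # shape (2, 512, 4096)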
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
import torch.nn as nn
import torch
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
#Resample model
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
from monkey_model.text_monkey.merge import *
class FeedForward(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
self.scale = nn.Parameter(torch.ones(1))
with torch.no_grad():
nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
nn.init.zeros_(self.fc2.weight)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.scale*x
return x
class Block(nn.Module):
def __init__(self, input_size,output_size):
super().__init__()
self.fc_1 = nn.Linear(input_size, output_size)
self.norm = nn.LayerNorm(output_size)
def forward(self, x):
x = self.fc_1(x)
x = self.norm(x)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm_media = nn.LayerNorm(dim)
self.norm_latents = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
x = self.norm_media(x)
latents = self.norm_latents(latents)
h = self.heads
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
# attention
sim = einsum("... i d, ... j d -> ... i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
return self.to_out(out)
class PerceiverResampler(nn.Module):
def __init__(
self,
*,
in_dim=1024,
out_dim=4096,
depth=1,
dim_head=128,
heads=8,
visual_tokens_num=512,
ff_mult=4,
):
super().__init__()
self.downsample = nn.Linear(out_dim,in_dim,bias=False)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=in_dim, dim_head=dim_head, heads=heads),
FeedForward(in_features=in_dim, hidden_features=in_dim,out_features=out_dim),
]
)
)
def forward(self, x,r=0):
B,L,C = x.shape
merge = self_soft_matching(x, r) # Replace with your features and r value
latents = merge(x)
down_x = self.downsample(x)
down_latent = self.downsample(latents)
for attn, ff in self.layers:
down_latent = attn(down_x, down_latent)
latents = ff(down_latent) + latents
return latents
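# Usage sketch (shapes follow the defaults above, which are assumptions of this note):
# with in_dim=1024, out_dim=4096 and visual tokens x of shape (B, L, 4096), forward(x, r)
# first keeps the r least redundant tokens via self_soft_matching, then refines them with
# Perceiver-style cross-attention against the full (down-projected) token set:
#   resampler = PerceiverResampler()
#   out = resampler(torch.randn(1, 1024, 4096), r=512)   # out: (1, 512, 4096)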
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List
import numpy as np
import sys
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from monkey_model.text_monkey.window import CrossWindowAttention,PatchMerging
from monkey_model.text_monkey.resampler import *
import random
import ipdb
import numpy as np
import matplotlib.pyplot as plt
def reconstruct_matrix(windows):
temp =[]
for col in windows:
temp.append(torch.cat((col),dim=3))
all_img = torch.cat(temp,dim=2)
return all_img
def sliding_window(matrix, window_size, stride):
b,c,height, width = matrix.shape
window_rows = math.ceil((height - window_size[0]) / stride) + 1
window_cols = math.ceil((width - window_size[1]) / stride) + 1
#windows = np.zeros((window_rows, window_cols, window_size[0], window_size[1]))
windows = []
for i in range(window_rows):
windows_col = []
for j in range(window_cols):
window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
windows_col.append(window)
windows.append(windows_col)
return windows
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross-attention layer, driven by
(grid_size**2) learnable queries and 2D sincos positional embeddings.
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=nn.LayerNorm
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
return out.permute(1, 0, 2)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Lora_Adapter(nn.Module):
def __init__(self,
d_model=None,
out_feat=None,
r=16,
dropout=0.05):
super().__init__()
self.d_model = d_model
self.out_feat = out_feat
self.r = r
self.lora_scale = nn.Parameter(torch.ones(1))
self.lora_a = nn.Linear(self.d_model, self.r,bias=False)
self.lora_b = nn.Linear(self.r, self.out_feat,bias=False)
self.lora_dropout = nn.Dropout(p=dropout)
with torch.no_grad():
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_b.weight)
def forward(self, x ):
#residual = x if residual is None else residual
x = self.lora_dropout(x)
down = self.lora_a(x)
up = self.lora_b(down)
up = up * self.lora_scale
output = up
return output
class VisualAttention(nn.Module):
"""self-attention layer class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, embed_dim, num_heads,
bias=True, kdim=None, vdim=None,lora_repeat_num=4):
super(VisualAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
# Per attention head and per partition values.
assert embed_dim % num_heads == 0
self.hidden_size_per_attention_head = embed_dim // num_heads
self.num_attention_heads_per_partition = num_heads
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.in_proj_lora = []
for _ in range(lora_repeat_num):
self.in_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=3 * embed_dim))
self.in_proj_lora = nn.ModuleList(self.in_proj_lora)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj_lora = []
for _ in range(lora_repeat_num):
self.out_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=embed_dim))
self.out_proj_lora = nn.ModuleList(self.out_proj_lora)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(self, query, key, value, attn_mask = None,lora_idx = None):
# query/key/value: [sq, b, h]
sq, b, _ = query.size()
assert query is key, 'Only Support Self-Attention Currently'
sk = sq
mixed_x_layer = self.in_proj(query)
if lora_idx == None:
pass
else:
lora_res = self.in_proj_lora[lora_idx](query)
mixed_x_layer += lora_res
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = mixed_x_layer.split(
self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(sq,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
else:
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
value_layer = value_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(b,
self.num_attention_heads_per_partition,
sq, self.hidden_size_per_attention_head)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
if lora_idx == None:
pass
else:
lora_res = self.out_proj_lora[lora_idx](context_layer)
output += lora_res
return output
class VisualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
is_cross_attention: bool = False,
lora_repeat_num = 4,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head,lora_repeat_num = lora_repeat_num)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.mlp_lora = []
for _ in range(lora_repeat_num):
self.mlp_lora.append(Lora_Adapter(d_model=d_model,out_feat=d_model,r=32))
self.mlp_lora = nn.ModuleList(self.mlp_lora)
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
lora_idx = None
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(q_x, k_x, v_x, attn_mask=attn_mask,lora_idx=lora_idx)
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
lora_idx = None
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask,lora_idx=lora_idx)
residual = x
x = x + self.mlp(self.ln_2(x))
if lora_idx == None:
pass
else:
x += self.mlp_lora[lora_idx](residual)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
lora_repeat_num=4,
add_window=False,
window_all=False,
image_size=(896,896)
):
super().__init__()
self.width = width
self.layers = layers
self.add_window = add_window
self.window_all = window_all
self.window_pos = [2,6,24,46]
self.window_dim = [128,256,512,1024]
self.window_head = [4,8,16,32]
if isinstance(image_size, tuple) or isinstance(image_size, list):
image_size = tuple(size // 14 for size in image_size)
else:
image_size = image_size//14
if self.add_window:
self.window_attention = []
for idx in range(len(self.window_pos)):
self.window_attention.append(CrossWindowAttention(image_size=image_size,dim=1664,hidden_dim=self.window_dim[idx],head=self.window_head[idx]))
self.window_attention = nn.ModuleList(self.window_attention)
self.resblocks = nn.ModuleList([
VisualAttentionBlock(
width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer,lora_repeat_num=lora_repeat_num)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def get_cast_device(self) -> torch.device:
return self.resblocks[0].mlp.c_fc.weight.device
def forward(self, x,attn_mask: Optional[torch.Tensor] = None,lora_idx=None,image_size=(64,64)):
if isinstance(x,List):
window_idx = 0
for r_idx,r in enumerate(self.resblocks):
if self.add_window:
if r_idx in self.window_pos:
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].permute(1, 0, 2) # LND -> NLD
x[i][j] = x[i][j].permute(0, 2, 1) # shape = [*, width, grid ** 2,]
x[i][j] = x[i][j].reshape(x[i][j].shape[0], x[i][j].shape[1], 32,32)
whole_image = reconstruct_matrix(x) #shape = [*,width,grid,grid]
whole_image = self.window_attention[window_idx](whole_image,image_size)
x = sliding_window(whole_image,(32,32),32)
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].reshape(x[i][j].shape[0], x[i][j].shape[1], -1)
x[i][j] = x[i][j].permute(0, 2, 1) # shape = [*, grid ** 2, width]
x[i][j] = x[i][j].permute(1, 0, 2) # NLD -> LND
window_idx += 1
if lora_idx is None or lora_idx == 0 :
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = r(x[i][j],attn_mask=attn_mask,lora_idx=lora_idx)
else:
temp_lora_idx = 0
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = r(x[i][j],attn_mask=attn_mask,lora_idx=temp_lora_idx)
temp_lora_idx += 1
return x
else:
for r in self.resblocks:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
def __init__(
self,
image_size,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
n_queries: int = 256,
output_dim: int = 512,
lora_repeat_num: int = 0,
add_window: bool = False,
use_global:bool =False,
resampler=False,
r=512,
**kwargs
):
super().__init__()
if isinstance(image_size, tuple) or isinstance(image_size, list):
image_height, image_width = self.image_size = image_size
else:
image_height, image_width = self.image_size = (image_size,image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
self.add_window = add_window
self.use_global = use_global
self.resampler = resampler
self.r = r
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
self.image_transform = transforms.Compose([
transforms.Resize(
self.image_size,
interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
self.lora_repeat_num = lora_repeat_num
# class embeddings and positional embeddings
scale = width ** -0.5
self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.ln_pre = norm_layer(width)
self.transformer = TransformerBlock(
width,
layers,
heads,
mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
lora_repeat_num=lora_repeat_num,
add_window=add_window,
image_size=image_size
)
self.attn_pool = Resampler(
grid_size=int(math.sqrt(256)),
embed_dim=output_dim,
num_heads=output_dim // 128,
kv_dim=width,
norm_layer=norm_layer,
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
if self.resampler:
self.downresampler = PerceiverResampler()
def forward(self, x: torch.Tensor,lora_idx=None,add_window=False):
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [b, width, grid, grid]
b,c,h,w = x.shape
if add_window:
x = sliding_window(x,(32,32),32)
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].reshape(x[i][j].shape[0], x[i][j].shape[1], -1)
x[i][j] = x[i][j].permute(0, 2, 1) # shape = [*, grid ** 2, width]
x[i][j] = x[i][j] + get_abs_pos(self.positional_embedding,x[i][j].size(1))
x[i][j] = self.ln_pre(x[i][j])
x[i][j] = x[i][j].permute(1, 0, 2) # NLD -> LND
else:
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, x.size(1))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x,lora_idx=lora_idx,image_size=(h,w))
if add_window:
for i in range(len(x)):
for j in range(len(x[i])):
x[i][j] = x[i][j].permute(1, 0, 2) # LND -> NLD
x[i][j] = self.attn_pool(x[i][j])
x[i][j] = self.ln_post(x[i][j])
x[i][j] = x[i][j] @ self.proj
temp =[]
for col in x:
temp.append(torch.cat((col),dim=1))
x = torch.cat(temp,dim=1)
else:
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x
def encode(self, image_paths: List[str],lora_idx=None,input_image=None):
if input_image is None:
images = []
for image_path in image_paths:
if image_path.startswith("http://") or image_path.startswith("https://"):
image = Image.open(requests.get(image_path, stream=True).raw)
else:
image = Image.open(image_path)
image = image.convert("RGB")
## to imitate transmission loss in the real world.
if self.training:
output = BytesIO()
qual = random.randint(20, 100)
image.save(output, format='JPEG', quality=qual)
image_data = output.getvalue()
image =Image.open(BytesIO(image_data))
images.append(self.image_transform(image))
images = torch.stack(images, dim=0)
else:
images = input_image
images_448 = F.interpolate(images, size=(448,448), mode='bicubic')
if lora_idx == 1:
local_feat = self(images,0,add_window=True)
else:
local_feat = self(images,lora_idx,add_window=True)
if self.resampler:
local_feat = self.downresampler(local_feat,r = self.r)
if self.use_global:
global_feat = self(images_448,lora_idx=None,add_window=False)
return torch.cat([local_feat,global_feat],dim=1)
else:
return local_feat
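# Usage note (values taken from the code above): encode() resizes each image to
# self.image_size, optionally JPEG-compresses it during training to imitate
# transmission loss, slices the 14x14-patch grid into 32x32-patch (448x448-pixel)
# windows, runs the shared ViT on every window (optionally with a per-window LoRA
# adapter), resamples each window to 256 tokens via attn_pool, and concatenates
# them; with use_global=True a 448x448 global view is appended, and with
# resampler=True the tokens are further reduced to self.r by the PerceiverResampler.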
import logging
import math
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import numpy as np
from itertools import repeat
import collections.abc
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from functools import partial
from itertools import repeat
import collections.abc
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < round_limit * v:
new_v += divisor
return new_v
def extend_tuple(x, n):
# pads a tuple to specified n by padding with last value
if not isinstance(x, (tuple, list)):
x = (x,)
else:
x = tuple(x)
pad_n = n - len(x)
if pad_n <= 0:
return x[:n]
return x + (x[-1],) * pad_n
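# Note: DropPath.forward below calls drop_path(), which is not defined in this
# snippet. A standard stochastic-depth implementation (as in timm) would be:
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples in the batch with probability drop_prob."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over all non-batch dims
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor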
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=False)
self.act = act_layer()
self.drop1 = nn.Dropout(0.05)
self.fc2 = linear_layer(hidden_features, out_features, bias=False)
self.scale = nn.Parameter(torch.ones(1))
with torch.no_grad():
nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
nn.init.zeros_(self.fc2.weight)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.scale*x
return x
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
def window_partition(
x: torch.Tensor,
window_size: Tuple[int, int],
) -> torch.Tensor:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x
def get_relative_position_index(win_h: int, win_w: int):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_h - 1 # shift to start from 0
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
return relative_coords.sum(-1) # Wh*Ww, Wh*Ww
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.reduction_2 = nn.Linear(2 * dim, dim, bias=False)
self.norm = norm_layer(4 * dim)
self.norm_2 = norm_layer(2 * dim)
def forward(self, x):
"""
X bxcxgxg
x: B, H*W, C
"""
size= self.input_resolution
B, C, G,_ = x.shape
assert G*G == size * size, "input feature has wrong size"
x = x.reshape(x.shape[0],x.shape[1],-1) #bxcxl
x = x.permute(0,2,1) #bxlxc
B, _, C = x.shape
x = x.view(B, size, size, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
x = self.norm_2(x)
x = self.reduction_2(x)
x = x.view(B,-1,C)
x = x.permute(0,2,1) #bxcxl
x = x.reshape(x.shape[0],x.shape[1],G//2,G//2) #bxcxl
return x
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports shifted and non-shifted windows.
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
num_heads: int,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
qkv_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
):
"""
Args:
dim: Number of input channels.
num_heads: Number of attention heads.
head_dim: Number of channels per head (dim // num_heads if not set)
window_size: The height and width of the window.
qkv_bias: If True, add a learnable bias to query, key, value.
attn_drop: Dropout ratio of attention weight.
proj_drop: Dropout ratio of output.
"""
super().__init__()
self.window_size = to_2tuple(window_size) # Wh, Ww
win_h, win_w = self.window_size
self.window_area = win_h * win_w
self.num_heads = num_heads
head_dim = head_dim or dim // num_heads
attn_dim = head_dim * num_heads
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
# get pair-wise relative position index for each token inside the window
self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias.unsqueeze(0)
def forward(self, x, mask: Optional[torch.Tensor] = None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn + self._get_rel_pos_bias()
if mask is not None:
num_win = mask.shape[0]
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B_, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
"""
def __init__(
self,
dim: int,
hidden_dim:int,
input_resolution: _int_or_tuple_2_t,
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
shift_size: int = 0,
mlp_ratio: float = 1.,
qkv_bias: bool = True,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
window_size: Window size.
num_heads: Number of attention heads.
head_dim: Enforce the number of channels per head
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
proj_drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.input_resolution = input_resolution
ws, ss = self._calc_window_shift(window_size, shift_size)
self.window_size: Tuple[int, int] = ws
self.shift_size: Tuple[int, int] = ss
self.window_area = self.window_size[0] * self.window_size[1]
self.mlp_ratio = mlp_ratio
self.downsample = nn.Linear(dim,hidden_dim,bias=False)
self.norm1 = norm_layer(hidden_dim)
self.attn = WindowAttention(
hidden_dim,
num_heads=num_heads,
head_dim=head_dim,
window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(hidden_dim)
self.mlp = Mlp(
in_features=hidden_dim,
hidden_features=int(hidden_dim * 2),
out_features = 1664,
act_layer=act_layer,
drop=proj_drop,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
attn_mask = self.calc_attn(self.input_resolution)
self.register_buffer("attn_mask", attn_mask, persistent=False)
def calc_attn(self,input_resolution):
H, W = input_resolution
H = math.ceil(H / self.window_size[0]) * self.window_size[0]
W = math.ceil(W / self.window_size[1]) * self.window_size[1]
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = to_2tuple(target_shift_size)
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)
def _attn(self, x):
B, H, W, C = x.shape
# cyclic shift
has_shift = any(self.shift_size)
if has_shift:
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
else:
shifted_x = x
# pad for resolution not divisible by window size
pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C
attn_mask = self.attn_mask
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
shifted_x = shifted_x[:, :H, :W, :].contiguous()
# reverse cyclic shift
if has_shift:
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
return x
def forward(self, x):
x = self.downsample(x)
B, H, W, C = x.shape
# C = hidden_dim
x = x + self.drop_path1(self._attn(self.norm1(x)))
x = x.reshape(B, -1, C)
x = self.drop_path2(self.mlp(self.norm2(x)))
x = x.reshape(B, H, W, self.dim)
return x
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(1)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
class CrossWindowAttention(nn.Module):
""" Patch Merging Layer.
"""
def __init__(
self,
image_size,
dim,
hidden_dim,
head,
window_size = 12
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels (or 2 * dim if None)
norm_layer: Normalization layer.
"""
super().__init__()
if isinstance(image_size, tuple) or isinstance(image_size, list):
self.image_size = image_size
else:
self.image_size = (image_size,image_size)
self.dim = dim
self.window_size = window_size
self.shift_size = window_size // 2
self.position_embedding = nn.Parameter(torch.zeros(1, self.image_size[0]*self.image_size[1], 1664))
trunc_normal_(self.position_embedding, std=.02)
self.shift_attn = SwinTransformerBlock(dim=dim,hidden_dim=hidden_dim,input_resolution = self.image_size,num_heads=head,window_size =self.window_size,shift_size=self.shift_size)
def forward(self, x,image_size):
# X bxcxgxg
B,C,G,_=x.shape
x = x.reshape(x.shape[0],x.shape[1],-1) #bxcxl
x = x.permute(0,2,1) #bxlxc
B, L, C = x.shape
residual = x
H,W = image_size
pos_embed = get_abs_pos(self.position_embedding,x.size(1))
x = x + pos_embed
x = x.view(B,H,W,C)
x = self.shift_attn(x)
x = x.view(B,-1,C)
x = x + residual
x = x.permute(0,2,1) #bxcxl
x = x.reshape(x.shape[0],x.shape[1],G,G) #bxcxl
return x
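# Design note: CrossWindowAttention lets features from different windows interact.
# It adds a (zero-initialized) learned position embedding over the full patch grid,
# runs a single shifted-window (Swin-style) attention block whose MLP output weights
# are zero-initialized (see Mlp above), and adds the result back as a residual, so at
# initialization the module acts as an identity mapping. This is the "zero-initialized
# Shifted Window Attention" described in the README.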
@@ -41,7 +41,7 @@ SPECIAL_TOKENS = (
IMSTART,
IMEND,
) + EXTRAS
IMG_TOKEN_SPAN = 1280 # IMG_TOKEN_SPAN = 1024
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
@@ -128,6 +128,7 @@ class QWenTokenizer(PreTrainedTokenizer):
image_start_tag, image_end_tag,
image_pad_tag
)
self.IMG_TOKEN_SPAN = 1280
self.errors = errors  # how to handle errors in decoding
@@ -272,10 +273,10 @@ class QWenTokenizer(PreTrainedTokenizer):
img_tokens = img_tokens[1:-1]
img_url = b''.join(img_tokens)
out_img_tokens = list(map(self.decoder.get, img_url))
if len(out_img_tokens) > self.IMG_TOKEN_SPAN:
raise ValueError("The content in {}..{} is too long".format(
self.image_start_tag, self.image_end_tag))
out_img_tokens.extend([self.image_pad_tag] * (self.IMG_TOKEN_SPAN - len(out_img_tokens)))
out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
return out_img_tokens
...
@@ -436,7 +436,10 @@ class VisionTransformer(nn.Module):
**kwargs
):
super().__init__()
if isinstance(image_size, tuple) or isinstance(image_size, list):
image_height, image_width = self.image_size = image_size
else:
image_height, image_width = self.image_size = (image_size,image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
...