"examples/pytorch/graphsage/dist/train_dist_unsupervised.py" did not exist on "a4c931a9bf182b78bdffee7b71142122eceafe75"
Commit 0911606e authored by wanglch's avatar wanglch
Browse files

Initial commit

parent bde8d813
# TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySD7w.png" width="300"/>
</p>
> [**TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document**](https://arxiv.org/abs/2403.04473)<br>
> Yuliang Liu, Biao Yang, Qiang Liu, Zhang Li, Zhiyin Ma, Shuo Zhang, Xiang Bai <br>
[![arXiv](https://img.shields.io/badge/Arxiv-2403.04473-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.04473)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](monkey_model/text_monkey/README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7684/)
[![Data](https://img.shields.io/badge/Data-yellow)](https://huggingface.co/datasets/MelosY/TextMonkey_Data)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](https://www.modelscope.cn/models/lvskiller/TextMonkey)
-----
**TextMonkey** is a large multimodal model (LMM) focused on text-centric tasks, including document question answering and scene-text question answering. Compared with Monkey, TextMonkey improves on several fronts: zero-initialized Shifted Window Attention enables information interaction between windows at higher input resolutions, and similarity-based filtering of important image features both simplifies the input and improves performance. TextMonkey further enhances interpretability and reduces hallucinations by extending to multiple text-related tasks and incorporating location information into its responses. In addition, after fine-tuning, TextMonkey can understand user instructions and click the corresponding location as an APP agent, demonstrating its potential for downstream applications.
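To illustrate the zero-initialization idea in isolation, here is a minimal, generic sketch (not the repository's Shifted Window Attention implementation): a learnable gate initialized to zero lets a newly added attention branch start as an identity mapping, so the pretrained behavior is preserved at the beginning of training.
```python
import torch
import torch.nn as nn

class ZeroInitGatedAttention(nn.Module):
    """Generic sketch: a newly added attention branch gated by a zero-initialized scalar."""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))  # zero init: the new branch contributes nothing at step 0

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return x + self.gate * out  # output equals the pretrained path at initialization
```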
# TODO
- [x] Open-source the code, weights, and data
- [ ] Support training on 3090 GPUs (24 GB GPU memory)
- [ ] Improve Chinese language proficiency
- [ ] TextMonkey with different LLMs
# Model Zoo
TextMonkey was trained on 8 A800 GPUs with approximately 400k samples, requiring about 1 day and 6 hours of training time. Inference can run on a single 3090 GPU.
| Method | LLM | STVQA | TextVQA | OCRVQA | DocVQA | InfoVQA | ChartQA | FUNSD | SROIE | POIE | OCRBench |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP2-OPT-6.7B | OPT-6.7B | 20.9 | 23.5 | 9.7 | 3.2 | 11.3 | 3.4 | 0.2 | 0.1 | 0.3 | 235 |
| mPLUG-Owl | LLaMA-7B | 30.5 | 34.0 | 21.1 | 7.4 | 20.0 | 7.9 | 0.5 | 1.7 | 2.5 | 297 |
| InstructBLIP | Vicuna-7B | 27.4 | 29.1 | 41.3 | 4.5 | 16.4 | 5.3 | 0.2 | 0.6 | 1.0 | 276 |
| LLaVAR | Vicuna-7B | 39.2 | 41.8 | 24.0 | 12.3 | 16.5 | 12.2 | 0.5 | 5.2 | 5.9 | 346 |
| BLIVA | Vicuna-7B | 32.1 | 33.3 | 50.7 | 5.8 | 23.6 | 8.7 | 0.2 | 0.7 | 2.1 | 291 |
| mPLUG-Owl2 | LLaMA-7B | 49.8 | 53.9 | 58.7 | 17.9 | 18.9 | 19.4 | 1.4 | 3.2 | 9.9 | 366 |
| LLaVA1.5-7B$ | Vicuna-7B | 38.1 | 38.7 | 58.1 | 8.5 | 14.7 | 9.3 | 0.2 | 1.7 | 2.5 | 297 |
| TGDoc$ | Vicuna-7B | 36.3 | 46.2 | 37.2 | 9.0 | 12.8 | 12.7 | 1.4 | 3.0 | 22.2 | - |
| UniDoc | Vicuna-7B | 35.2 | 46.2 | 36.8 | 7.7 | 14.7 | 10.9 | 1.0 | 2.9 | 5.1 | - |
| DocPedia | Vicuna-7B | 45.5 | 60.2 | 57.2 | 47.1 | 15.2 | 46.9 | 9.9 | 21.4 | 39.9 | - |
| Monkey | Qwen-7B | 54.7 | 64.3 | 64.4 | 50.1 | 25.8 | 54.0 | 24.1 | 41.9 | 19.9 | 514 |
| InternVL | - | 62.2 | 59.8 | 30.5 | 28.7 | 23.6 | 45.6 | 6.5 | 26.4 | 25.9 | 517 |
| InternLM-XComposer2 | InternLM-7B | 59.6 | 62.2 | 49.6 | 39.7 | 28.6 | 51.6 | 15.3 | 34.2 | 49.3 | 511 |
| TextMonkey (~400k data)| Qwen-7B | 61.8 | 65.9 | 71.3 | 64.3 | 28.2 | 58.2 | 32.3 | 47.0 | 27.9 | 561 |
| TextMonkey (~500k data) | Qwen-7B | 61.2 | 64.3 | 72.2 | 66.7 | 28.6 | 59.9 | 42.9 | 46.2 | 32.0 | 558 |
## Environment
```bash
conda create -n textmonkey python=3.10
conda activate textmonkey
git clone https://github.com/Yuliang-Liu/Monkey.git
cd ./Monkey
pip install -r requirements.txt
```
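Before running evaluation or training, a quick sanity check (a minimal sketch; it only assumes the PyTorch install pulled in by `requirements.txt`) confirms that a CUDA device is visible:
```python
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # Inference fits on a single 24 GB card such as an RTX 3090.
    print("device:", torch.cuda.get_device_name(0))
```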
## Evaluate
We also provide TextMonkey's evaluation code. Run the evaluation with:
```bash
bash eval/eval_doc.sh
```
## Train
Run the training script:
```bash
bash finetune/finetune_textmonkey.sh
```
## Cases
TextMonkey can accurately locate and recognize text in both scene images and document images. The natural image in (a), the document in (b), the diagram in (c), and the table in (d) all demonstrate TextMonkey's ability to identify, understand, and locate text information in a variety of scenarios.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySSXO.png" width="700"/>
</p>
<br>
TextMonkey has shown strong feasibility as an agent for smartphone applications. After fine-tuning on 15k user-click samples from the Rico dataset, TextMonkey is able to understand user intent and click the corresponding icon.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySOV6.png" width="700"/>
</p>
<br>
## Citing TextMonkey
If you wish to refer to the baseline results published here, please use the following BibTeX entry:
```BibTeX
@article{liu2024textmonkey,
  title={TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document},
  author={Liu, Yuliang and Yang, Biao and Liu, Qiang and Li, Zhang and Ma, Zhiyin and Zhang, Shuo and Bai, Xiang},
  journal={arXiv preprint arXiv:2403.04473},
  year={2024}
}
```
## Copyright
We welcome suggestions to help us improve TextMonkey. For any queries, please contact Dr. Yuliang Liu at ylliu@hust.edu.cn. If you find something interesting, feel free to share it with us via email or by opening an issue.
import math
from typing import Callable, Tuple
import torch
def self_soft_matching(
    metric: torch.Tensor,
    r: int,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a merge function that keeps the r least redundant tokens of `metric`."""
    t = metric.shape[1]
    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., :, :], metric[..., :, :]
        scores = a @ b.transpose(-1, -2)  # pairwise cosine similarity, shape (B, t, t)
        b, _, _ = scores.shape
        # Mask the diagonal and lower triangle so each token is only compared
        # with tokens that come after it.
        scores_diag = torch.tril(torch.ones(t, t)) * 2
        scores_diag = scores_diag.expand(b, -1, -1).to(metric.device)
        scores = scores - scores_diag
        node_max, node_idx = scores.max(dim=-1)  # highest similarity each token has to a later token
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]  # sort tokens by redundancy, descending
        unm_idx = edge_idx[..., t - r:, :]  # the last r indices are the unmerged (kept) tokens

    def merge(src: torch.Tensor) -> torch.Tensor:
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, r, c))
        unm_idx_new = unm_idx
        all_idx = unm_idx_new
        # Re-sort so the kept tokens appear in their original sequence order.
        all_max, all_idx_idx = torch.sort(all_idx, dim=1)
        return unm.gather(dim=-2, index=all_idx_idx.expand(n, r, c))

    return merge
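# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): the shapes below are assumed
# for illustration. self_soft_matching scores tokens by pairwise cosine
# similarity and the returned `merge` keeps the r least redundant tokens in
# their original order.
if __name__ == "__main__":
    x = torch.randn(2, 1024, 1024)   # (batch, visual tokens, feature dim) -- assumed sizes
    r = 512                          # number of tokens to keep
    merge = self_soft_matching(x, r)
    kept = merge(x)
    print(kept.shape)                # torch.Size([2, 512, 1024])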
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
import torch.nn as nn
import torch
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
# Resampler model
import math

from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum

from monkey_model.text_monkey.merge import *
class FeedForward(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.scale = nn.Parameter(torch.ones(1))
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
            nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.scale * x
        return x
class Block(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc_1 = nn.Linear(input_size, output_size)
        self.norm = nn.LayerNorm(output_size)

    def forward(self, x):
        x = self.fc_1(x)
        x = self.norm(x)
        return x
class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads
        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        return self.to_out(out)
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        in_dim=1024,
        out_dim=4096,
        depth=1,
        dim_head=128,
        heads=8,
        visual_tokens_num=512,
        ff_mult=4,
    ):
        super().__init__()
        self.downsample = nn.Linear(out_dim, in_dim, bias=False)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=in_dim, dim_head=dim_head, heads=heads),
                        FeedForward(in_features=in_dim, hidden_features=in_dim, out_features=out_dim),
                    ]
                )
            )

    def forward(self, x, r=0):
        B, L, C = x.shape
        merge = self_soft_matching(x, r)  # build the similarity-based token filter for this batch
        latents = merge(x)
        down_x = self.downsample(x)
        down_latent = self.downsample(latents)
        for attn, ff in self.layers:
            down_latent = attn(down_x, down_latent)
            latents = ff(down_latent) + latents
        return latents
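# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): with the default in_dim=1024
# and out_dim=4096, the resampler takes LLM-width features and keeps `r`
# tokens selected by self_soft_matching. The sizes below are assumptions.
if __name__ == "__main__":
    resampler = PerceiverResampler()           # in_dim=1024, out_dim=4096, depth=1
    feats = torch.randn(2, 1280, 4096)         # (batch, visual tokens, out_dim) -- assumed sizes
    out = resampler(feats, r=512)              # keep 512 tokens, resampled by one attention block
    print(out.shape)                           # torch.Size([2, 512, 4096])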
{
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_qwen.QWenTokenizer",
      null
    ]
  },
  "clean_up_tokenization_spaces": true,
  "model_max_length": 2048,
  "padding_side": "right",
  "tokenizer_class": "QWenTokenizer"
}
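The `auto_map` entry above points `AutoTokenizer` at the custom `tokenization_qwen.QWenTokenizer` class shipped with the checkpoint, so loading it requires `trust_remote_code`. A minimal sketch (the checkpoint path is a placeholder):
```python
from transformers import AutoTokenizer

checkpoint = "path/to/TextMonkey"  # placeholder: local checkpoint directory containing this config
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

print(tokenizer.__class__.__name__)  # QWenTokenizer
print(tokenizer.model_max_length)    # 2048, as set above
```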
CUDA_VISIBLE_DEVICES=3,4 python demo_textmonkey.py -c ../TextMonkey/TextMonkey_base