Commit b81b2f59 authored by wanglch's avatar wanglch
Browse files

Initial commit

parent f7c86e68
This diff is collapsed.
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
from monkey_model.modeling_qwen import QWenModel,QWenPreTrainedModel,QWenLMHeadModel
from monkey_model.text_monkey.visual_text import VisionTransformer
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
logger = logging.get_logger(__name__)
class TextMonkeyModel(QWenModel):
def __init__(self, config):
super().__init__(config)
self.visual = VisionTransformer(**config.visual)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
if self.visual.lora_repeat_num>0:
images = self.visual.encode(images,lora_idx=self.visual.lora_repeat_num)
else:
images = self.visual.encode(images)
assert images.shape[0] == len(images)
else:
images = None
return super().forward(input_ids,
past_key_values,
attention_mask,
token_type_ids,
position_ids,
head_mask,inputs_embeds,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
images)
class TextMonkeyLMHeadModel(QWenLMHeadModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = TextMonkeyModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
logger = logging.get_logger(__name__)
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
return batch
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
"""Generate batch from context tokens."""
# Move to GPU.
tokens = context_tokens.contiguous().to(context_tokens.device)
# Get the attention mask and postition ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_id,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
)
return tokens, attention_mask, position_ids
def get_stop_words_ids(chat_format, tokenizer):
if chat_format == "raw":
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
elif chat_format == "chatml":
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return stop_words_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144,
chat_format: str = "chatml",
):
if history is None:
history = []
if chat_format == "chatml":
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set(tokenizer.IMAGE_ST)
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
elif chat_format == "raw":
raw_text = query
context_tokens = tokenizer.encode(raw_text)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return raw_text, context_tokens
def _decode_default(
tokens: List[int],
*,
stop_words: List[str],
eod_words: List[str],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace',
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate: ", trim_decode_tokens)
end_reason = f"Gen length {len(tokens)}"
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
for eod_word in eod_words:
if eod_word in trim_decode_tokens:
end_reason = f"Gen {eod_word!r}"
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nEnd Reason:", end_reason)
print("\nGenerate: ", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def _decode_chatml(
tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace'
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in eod_token_ids:
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
break
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
print("\nRaw Generate:", trim_decode_tokens)
print("\nEnd Reason:", end_reason)
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nGenerate:", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str="replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
if chat_format == "chatml":
return _decode_chatml(
tokens,
stop_words=[],
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
context_length=context_length,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
elif chat_format == "raw":
return _decode_default(
tokens,
stop_words=["<|endoftext|>"],
eod_words=["<|endoftext|>"],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
class StopWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
Args:
stop_words_ids (:obj:`List[List[int]]`):
List of list of token ids of stop ids. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)
self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
scores[i, self.eos_token_id] = float(2**15)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
stopped_samples = []
for prev_input_ids_slice in prev_input_ids:
match = False
for stop_token_seq in self.stop_words_ids:
if self._tokens_match(prev_input_ids_slice, stop_token_seq):
# if tokens do not match continue
match = True
break
stopped_samples.append(match)
return stopped_samples
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Cconvert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
{
"pad_token": "<|endoftext|>"
}
# TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySD7w.png" width="300"/>
<p>
> [**TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document**](https://arxiv.org/abs/2403.04473)<br>
> Yuliang Liu, Biao Yang, Qiang Liu, Zhang Li, Zhiyin Ma, Shuo Zhang, Xiang Bai <br>
[![arXiv](https://img.shields.io/badge/Arxiv-2403.04473-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.04473)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](monkey_model/text_monkey/README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7684/)
[![Data](https://img.shields.io/badge/Data-yellow)](https://huggingface.co/datasets/MelosY/TextMonkey_Data)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](https://www.modelscope.cn/models/lvskiller/TextMonkey)
-----
**TextMonkey** is a multi-modal large model (LMM) focused on text-related tasks, including document question answering and scene text question answering. Compared with Monkey, TextMonkey has been improved in many aspects: by using zero-initialized Shifted Window Attention, TextMonkey realizes information interaction between windows at a higher input resolution; by calculating similarity to filter out important image features, not only can it simplify the input, but it can also improve the performance of the model. Furthermore, TextMonkey enhances interpretability and reduces hallucinations by extending multiple text-related tasks and incorporating location information into responses. At the same time, after fine-tuning, TextMonkey can also have the ability to understand user instructions and click on the corresponding location in the APP Agent, demonstrating its huge potential for downstream applications.
# TODO
- [x] Open source code, weight, and data
- [ ] Support training using 3090 GPUs (24Gb video memory)
- [ ] Improve Chinese language proficiency
- [ ] TextMonkey with different LLMs
# Model Zoo
TextMonkey was trained using 8 A800 GPUs on a dataset of 400k data, requiring approximately 1 day and 6 hours of training time. It is capable of running inference on a 3090 GPU.
| Method | LLM | STVQA | TextVQA | OCRVQA | DocVQA | InfoVQA | ChartQA | FUNSD | SROIE | POIE | OCRBench |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP2-OPT-6.7B | OPT-6.7B | 20.9 | 23.5 | 9.7 | 3.2 | 11.3 | 3.4 | 0.2 | 0.1 | 0.3 | 235 |
| mPLUG-Owl | LLaMA-7B | 30.5 | 34.0 | 21.1 | 7.4 | 20.0 | 7.9 | 0.5 | 1.7 | 2.5 | 297 |
| InstructBLIP | Vircuna-7B | 27.4 | 29.1 | 41.3 | 4.5 | 16.4 | 5.3 | 0.2 | 0.6 | 1.0 | 276 |
| LLaVAR | Vircuna-7B | 39.2 | 41.8 | 24.0 | 12.3 | 16.5 | 12.2 | 0.5 | 5.2 | 5.9 | 346 |
| BLIVA | Vircuna-7B | 32.1 | 33.3 | 50.7 | 5.8 | 23.6 | 8.7 | 0.2 | 0.7 | 2.1 | 291 |
| mPLUG-Owl2 | LLaMA-7B | 49.8 | 53.9 | 58.7 | 17.9 | 18.9 | 19.4 | 1.4 | 3.2 | 9.9 | 366 |
| LLaVA1.5-7B$ | Vircuna-7B | 38.1 | 38.7 | 58.1 | 8.5 | 14.7 | 9.3 | 0.2 | 1.7 | 2.5 | 297 |
| TGDoc$ | Vircuna-7B | 36.3 | 46.2 | 37.2 | 9.0 | 12.8 | 12.7 | 1.4 | 3.0 | 22.2 | - |
| UniDoc | Vircuna-7B | 35.2 | 46.2 | 36.8 | 7.7 | 14.7 | 10.9 | 1.0 | 2.9 | 5.1 | - |
| DocPedia | Vircuna-7B | 45.5 | 60.2 | 57.2 | 47.1 | 15.2 | 46.9 | 9.9 | 21.4 | 39.9 | - |
| Monkey | Qwen-7B | 54.7 | 64.3 | 64.4 | 50.1 | 25.8 | 54.0 | 24.1 | 41.9 | 19.9 | 514 |
| InternVL | - | 62.2 | 59.8 | 30.5 | 28.7 | 23.6 | 45.6 | 6.5 | 26.4 | 25.9 | 517 |
| InternLM-XComposer2 | InternLM-7B | 59.6 | 62.2 | 49.6 | 39.7 | 28.6 | 51.6 | 15.3 | 34.2 | 49.3 | 511 |
| TextMonkey (~400k data)| Qwen-7B | 61.8 | 65.9 | 71.3 | 64.3 | 28.2 | 58.2 | 32.3 | 47.0 | 27.9 | 561 |
| TextMonkey (~500k data) | Qwen-7B | 61.2 | 64.3 | 72.2 | 66.7 | 28.6 | 59.9 | 42.9 | 46.2 | 32.0 | 558 |
## Environment
```python
conda create -n textmonkey python=3.10
conda activate textmonkey
git clone https://github.com/Yuliang-Liu/Monkey.git
cd ./Monkey
pip install -r requirements.txt
```
## Evaluate
We also offer TextMonkey's model testing code, which you can explore above. You can execute the training code through executing:
```python
bash eval/eval_doc.sh
```
## Train
Execute the training code:
```python
bash finetune/finetune_textmonkey.sh
```
## Cases
TextMonkey can accurately locate and recognize text in both scene images and document images. In addition, the natural image in (a), the document in (b), the diagram in (c), and the table in (d) all demonstrate TextMonkey’s ability to identify, understand, and locate text information in a variety of scenarios.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySSXO.png" width="700"/>
<p>
<br>
TextMonkey has shown strong feasibility as an agent for smartphone applications. After fine-tuning using 15k user click data from the Rico dataset, TextMonkey was able to understand user intent and click the corresponding icon.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySOV6.png" width="700"/>
<p>
<br>
## Citing TextMonkey
If you wish to refer to the baseline results published here, please use the following BibTeX entries:
```BibTeX
@article{liu2024textmonkey,
title={TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document},
author={Liu, Yuliang and Yang, Biao and Liu, Qiang and Li, Zhang and Ma, Zhiyin and Zhang, Shuo and Bai, Xiang},
journal={arXiv preprint arXiv:2403.04473},
year={2024}
}
```
## Copyright
We welcome suggestions to help us improve the TextMonkey. For any query, please contact Dr. Yuliang Liu: ylliu@hust.edu.cn. If you find something interesting, please also feel free to share with us through email or open an issue.
import math
from typing import Callable, Tuple
import torch
def self_soft_matching(
metric: torch.Tensor,
r: int,):
t = metric.shape[1]
with torch.no_grad():
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = metric[..., :, :], metric[..., :, :]
scores = a @ b.transpose(-1, -2) # a_lxb_l
b,_,_ = scores.shape
scores_diag = torch.tril(torch.ones(t,t))*2
scores_diag = scores_diag.expand(b, -1, -1).to(metric.device)
scores = scores-scores_diag
node_max, node_idx = scores.max(dim=-1) # a中最相似的点
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] # a中相似度排序并得到idx,降序
unm_idx = edge_idx[..., t-r:, :] # Unmerged Tokens # 后面的就是不merge的
def merge(src: torch.Tensor) -> torch.Tensor:
n, t1, c = src.shape
unm = src.gather(dim=-2, index=unm_idx.expand(n, r, c))
unm_idx_new = unm_idx
all_idx = unm_idx_new
all_max,all_idx_idx = torch.sort(all_idx,dim=1)
return unm.gather(dim=-2, index=all_idx_idx.expand(n, r, c))
return merge
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
import torch.nn as nn
import torch
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
#Resample model
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
from monkey_model.text_monkey.merge import *
class FeedForward(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
self.scale = nn.Parameter(torch.ones(1))
with torch.no_grad():
nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
nn.init.zeros_(self.fc2.weight)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.scale*x
return x
class Block(nn.Module):
def __init__(self, input_size,output_size):
super().__init__()
self.fc_1 = nn.Linear(input_size, output_size)
self.norm = nn.LayerNorm(output_size)
def forward(self, x):
x = self.fc_1(x)
x = self.norm(x)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm_media = nn.LayerNorm(dim)
self.norm_latents = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
x = self.norm_media(x)
latents = self.norm_latents(latents)
h = self.heads
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
# attention
sim = einsum("... i d, ... j d -> ... i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
return self.to_out(out)
class PerceiverResampler(nn.Module):
def __init__(
self,
*,
in_dim=1024,
out_dim=4096,
depth=1,
dim_head=128,
heads=8,
visual_tokens_num=512,
ff_mult=4,
):
super().__init__()
self.downsample = nn.Linear(out_dim,in_dim,bias=False)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=in_dim, dim_head=dim_head, heads=heads),
FeedForward(in_features=in_dim, hidden_features=in_dim,out_features=out_dim),
]
)
)
def forward(self, x,r=0):
B,L,C = x.shape
merge = self_soft_matching(x, r) # Replace with your features and r value
latents = merge(x)
down_x = self.downsample(x)
down_latent = self.downsample(latents)
for attn, ff in self.layers:
down_latent = attn(down_x, down_latent)
latents = ff(down_latent) + latents
return latents
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
{
"auto_map": {
"AutoTokenizer": [
"tokenization_qwen.QWenTokenizer",
null
]
},
"clean_up_tokenization_spaces": true,
"model_max_length": 2048,
"padding_side": "right",
"tokenizer_class": "QWenTokenizer"
}
This diff is collapsed.
CUDA_VISIBLE_DEVICES=3,4 python demo_textmonkey.py -c /home/wanglch/projects/TextMonkey/TextMonkey_base
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment