Unverified commit b63ace06 authored by Melos, committed by GitHub

TextMonkey (#75)



* textmonkey

* textmonkey code

* Delete README_cn.md

---------
Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
parent ab58e6f0
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="echo840/Monkey-Chat")  # echo840/Monkey-Chat, echo840/Monkey
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--question", type=str, default=None)
    args = parser.parse_args()

    checkpoint = args.model_path
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id

    img_path = args.image_path
    question = args.question
    if question == "Generate the detailed caption in English:" and "Monkey-Chat" not in checkpoint:
        query = f'<img>{img_path}</img> Generate the detailed caption in English: '  # detailed caption
    else:
        query = f'<img>{img_path}</img> {question} Answer: '  # VQA

    input_ids = tokenizer(query, return_tensors='pt', padding='longest')
    attention_mask = input_ids.attention_mask
    input_ids = input_ids.input_ids

    pred = model.generate(
        input_ids=input_ids.cuda(),
        attention_mask=attention_mask.cuda(),
        do_sample=False,
        num_beams=1,
        max_new_tokens=512,
        min_new_tokens=1,
        length_penalty=1,
        num_return_sequences=1,
        output_hidden_states=True,
        use_cache=True,
        pad_token_id=tokenizer.eod_id,
        eos_token_id=tokenizer.eod_id,
    )
    response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
    print(f"Question: {question} Answer: {response}")
...@@ -45,7 +45,7 @@ from .qwen_generation_utils import (
    StopWordsLogitsProcessor,
)
from .visual import VisionTransformer
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
logger = logging.get_logger(__name__)
...@@ -71,6 +71,142 @@ apply_rotary_emb_func = None
rms_norm = None
# Use flash attention; if your machine does not support it, set use_flash_attention to False.
use_flash_attention = True
def _import_flash_attn():
    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
    try:
        from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
        apply_rotary_emb_func = __apply_rotary_emb_func
    except ImportError:
        logger.warn(
            "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
        )

    # try:
    #     from flash_attn.ops.rms_norm import rms_norm as __rms_norm
    #     rms_norm = __rms_norm
    # except ImportError:
    #     logger.warn(
    #         "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
    #         "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
    #     )

    try:
        import flash_attn
        if not hasattr(flash_attn, '__version__'):
            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
        else:
            if int(flash_attn.__version__.split(".")[0]) >= 2:
                from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
            else:
                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
        flash_attn_unpadded_func = __flash_attn_unpadded_func
    except ImportError:
        logger.warn(
            "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention"
        )
class FlashSelfAttention(torch.nn.Module):
    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
    ):
        super().__init__()
        assert flash_attn_unpadded_func is not None, (
            "Please install FlashAttention first, " "e.g., with pip install flash-attn"
        )
        assert (
            rearrange is not None
        ), "Please install einops first, e.g., with pip install einops"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def unpad_input(self, hidden_states, attention_mask):
        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        hidden_states = hidden_states[indices]
        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

    def pad_input(self, hidden_states, indices, batch, seqlen):
        output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
                             dtype=hidden_states.dtype)
        output[indices] = hidden_states
        return rearrange(output, '(b s) ... -> b s ...', b=batch)

    def forward(self, q, k, v, attention_mask=None):
        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((i.is_cuda for i in (q, k, v)))
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        seqlen_out = seqlen_q

        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q.device,
        )

        if batch_size > 1 and attention_mask is not None:
            k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
            if q.size(0) == v.size(0):
                q = q[indices_k]
                cu_seqlens_q = cu_seqlens_k
                seqlen_q = seqlen_k
            v = v[indices_k]
        else:
            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * seqlen_k,
                step=seqlen_k,
                dtype=torch.int32,
                device=q.device,
            )

        if self.training:
            assert seqlen_k == seqlen_q
            is_causal = self.causal
            dropout_p = self.dropout_p
        else:
            is_causal = seqlen_q == seqlen_k
            dropout_p = 0

        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen_q,
            seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=is_causal,
        )
        if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
            output = self.pad_input(output, indices_k, batch_size, seqlen_out)
        else:
            new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
            output = output.view(new_shape)
        return output
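
# Illustrative sketch (hypothetical helper, CPU-runnable): the bookkeeping that unpad_input
# performs before the flash-attention call above, assuming the additive 4D mask convention
# used in this file (0 = attend, large negative = masked).
def _unpad_bookkeeping_sketch():
    mask = torch.zeros(2, 1, 1, 4)      # two samples of length 4
    mask[0, :, :, 3:] = -1e4            # sample 0: 3 valid tokens
    mask[1, :, :, 2:] = -1e4            # sample 1: 2 valid tokens
    valid = mask.squeeze(1).squeeze(1).eq(0)
    seqlens_in_batch = valid.sum(dim=-1, dtype=torch.int32)                  # tensor([3, 2])
    indices = torch.nonzero(valid.flatten(), as_tuple=False).flatten()       # tensor([0, 1, 2, 4, 5])
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens                                               # cu_seqlens == tensor([0, 3, 5])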
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
...@@ -144,6 +280,9 @@ class QWenAttention(nn.Module):
        self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]

        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
        if use_flash_attention:
            _import_flash_attn()
            self.core_attention_flash = FlashSelfAttention(causal=True, attention_dropout=0)
    def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
...@@ -297,12 +436,18 @@ class QWenAttention(nn.Module):
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
            query = query * logn_tensor.expand_as(query)
        if self.training and SUPPORT_TORCH2 and use_flash_attention:
            attn_output = self.core_attention_flash(query, key, value)
            attn_weight = None
        else:
            query = query.permute(0, 2, 1, 3)
            key = key.permute(0, 2, 1, 3)
            value = value.permute(0, 2, 1, 3)
            attn_output, attn_weight = self._attn(
                query, key, value, registered_causal_mask, attention_mask, head_mask
            )
        context_layer = self._merge_heads(
            attn_output, self.num_heads, self.head_dim
        )
...@@ -410,6 +555,10 @@ class QWenPreTrainedModel(PreTrainedModel):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        '''
        There is no need to re-initialize the weights here.
        '''
        return
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
...
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
from monkey_model.modeling_qwen import QWenModel,QWenPreTrainedModel,QWenLMHeadModel
from monkey_model.text_monkey.visual_text import VisionTransformer
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
logger = logging.get_logger(__name__)
class TextMonkeyModel(QWenModel):
    def __init__(self, config):
        super().__init__(config)
        self.visual = VisionTransformer(**config.visual)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
            bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
            eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
            assert (bos_pos[0] == eos_pos[0]).all()
            img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
            images = []
            for i, a, b in img_pos:
                image = input_ids[i][a + 1 : b - 1].tolist()
                image = image[: image.index(self.config.visual['image_start_id'] + 2)]
                images.append(bytes(image).decode('utf-8'))
            if self.visual.lora_repeat_num > 0:
                images = self.visual.encode(images, lora_idx=self.visual.lora_repeat_num)
            else:
                images = self.visual.encode(images)
            assert images.shape[0] == len(images)
        else:
            images = None

        return super().forward(
            input_ids,
            past_key_values,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            encoder_hidden_states,
            encoder_attention_mask,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            images,
        )
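
# Illustrative sketch (hypothetical ids and path): how the forward pass above recovers an
# image path that the tokenizer packed between the <img> ... </img> marker tokens. In the
# real model, image_start_id comes from config.visual; 151857 is only an assumed value,
# and the trailing 42/43 stand in for ordinary text token ids.
def _image_span_sketch(image_start_id: int = 151857, span: int = 16):
    path = "demo/doc.png".encode("utf-8")
    body = list(path) + [image_start_id + 2] * (span - len(path))        # path bytes + pad ids
    input_ids = torch.tensor([[image_start_id, *body, image_start_id + 1, 42, 43]])

    bos_pos = torch.where(input_ids == image_start_id)
    eos_pos = torch.where(input_ids == image_start_id + 1)
    img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
    paths = []
    for i, a, b in img_pos:
        toks = input_ids[i][a + 1 : b - 1].tolist()
        toks = toks[: toks.index(image_start_id + 2)]                    # strip the pad ids
        paths.append(bytes(toks).decode("utf-8"))
    return paths                                                         # ['demo/doc.png']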
class TextMonkeyLMHeadModel(QWenLMHeadModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
        if autoset_precision:
            if SUPPORT_BF16:
                logger.warn(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warn(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warn("Your device does NOT seem to support bf16; you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warn("Your device does NOT support faster inference with fp16; please switch to fp32, which is likely to be faster.")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warn("Your device supports faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warn("Your device supports faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

        self.transformer = TextMonkeyModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()
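
# Usage sketch, based on the precision warnings above: pin the compute dtype explicitly
# instead of relying on the automatic bf16/fp16 selection. The checkpoint name is only an
# example; any Monkey/TextMonkey checkpoint directory works the same way.
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained(
#         "echo840/Monkey-Chat",
#         device_map="cuda",
#         trust_remote_code=True,
#         bf16=True,              # or fp16=True / fp32=True; at most one may be set
#     ).eval()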
# TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySD7w.png" width="300"/>
<p>
> [**TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document**](https://arxiv.org/abs/2403.04473)<br>
> Yuliang Liu, Biao Yang, Qiang Liu, Zhang Li, Zhiyin Ma, Shuo Zhang, Xiang Bai <br>
[![arXiv](https://img.shields.io/badge/Arxiv-2403.04473-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.04473)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](monkey_model/text_monkey/README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7684/)
[![Data](https://img.shields.io/badge/Data-yellow)](https://www.modelscope.cn/datasets/lvskiller/TextMonkey_data)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](https://www.modelscope.cn/models/lvskiller/TextMonkey)
-----
**TextMonkey** is a large multimodal model (LMM) focused on text-centric tasks, including document question answering and scene-text question answering. Compared with Monkey, TextMonkey improves on several fronts: zero-initialized Shifted Window Attention enables information interaction between windows at higher input resolutions, and similarity-based filtering of important image features both simplifies the input and improves model performance. TextMonkey further enhances interpretability and reduces hallucinations by extending to multiple text-related tasks and incorporating location information into its responses. After fine-tuning, TextMonkey can also understand user instructions and click the corresponding location as an app agent, demonstrating its strong potential for downstream applications.
# TODO
- [x] Open-source code, weights, and data
- [ ] Improve Chinese language proficiency
- [ ] TextMonkey with different LLMs
# Model Zoo
| Method | LLM | STVQA | TextVQA | OCRVQA | DocVQA | InfoVQA | ChartQA | FUNSD | SROIE | POIE | OCRBench |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BLIP2-OPT-6.7B | OPT-6.7B | 20.9 | 23.5 | 9.7 | 3.2 | 11.3 | 3.4 | 0.2 | 0.1 | 0.3 | 235 |
| mPLUG-Owl | LLaMA-7B | 30.5 | 34.0 | 21.1 | 7.4 | 20.0 | 7.9 | 0.5 | 1.7 | 2.5 | 297 |
| InstructBLIP | Vicuna-7B | 27.4 | 29.1 | 41.3 | 4.5 | 16.4 | 5.3 | 0.2 | 0.6 | 1.0 | 276 |
| LLaVAR | Vicuna-7B | 39.2 | 41.8 | 24.0 | 12.3 | 16.5 | 12.2 | 0.5 | 5.2 | 5.9 | 346 |
| BLIVA | Vicuna-7B | 32.1 | 33.3 | 50.7 | 5.8 | 23.6 | 8.7 | 0.2 | 0.7 | 2.1 | 291 |
| mPLUG-Owl2 | LLaMA-7B | 49.8 | 53.9 | 58.7 | 17.9 | 18.9 | 19.4 | 1.4 | 3.2 | 9.9 | 366 |
| LLaVA1.5-7B$ | Vicuna-7B | 38.1 | 38.7 | 58.1 | 8.5 | 14.7 | 9.3 | 0.2 | 1.7 | 2.5 | 297 |
| TGDoc$ | Vicuna-7B | 36.3 | 46.2 | 37.2 | 9.0 | 12.8 | 12.7 | 1.4 | 3.0 | 22.2 | - |
| UniDoc | Vicuna-7B | 35.2 | 46.2 | 36.8 | 7.7 | 14.7 | 10.9 | 1.0 | 2.9 | 5.1 | - |
| DocPedia | Vicuna-7B | 45.5 | 60.2 | 57.2 | 47.1 | 15.2 | 46.9 | 9.9 | 21.4 | 39.9 | - |
| Monkey | Qwen-7B | 54.7 | 64.3 | 64.4 | 50.1 | 25.8 | 54.0 | 24.1 | 41.9 | 19.9 | 514 |
| InternVL | - | 62.2 | 59.8 | 30.5 | 28.7 | 23.6 | 45.6 | 6.5 | 26.4 | 25.9 | 517 |
| InternLM-XComposer2 | InternLM-7B | 59.6 | 62.2 | 49.6 | 39.7 | 28.6 | 51.6 | 15.3 | 34.2 | 49.3 | 511 |
| TextMonkey (40k data)| Qwen-7B | 61.8 | 65.9 | 71.3 | 64.3 | 28.2 | 58.2 | 32.3 | 47.0 | 27.9 | 561 |
| TextMonkey (50k data) | Qwen-7B | 61.2 | 64.3 | 72.2 | 66.7 | 28.6 | 59.9 | 42.9 | 46.2 | 32.0 | 558 |
## Environment
```bash
conda create -n textmonkey python=3.10
conda activate textmonkey
git clone https://github.com/MelosY/TextMonkey.git
cd ./TextMonkey
pip install -r requirements.txt
```
## Evaluate
We also provide TextMonkey's evaluation code, which you can explore above. Run the evaluation with:
```bash
bash eval/eval_doc.sh
```
## Train
Execute the training code:
```bash
bash finetune/finetune_textmonkey.sh
```
## Cases
TextMonkey can accurately locate and recognize text in both scene images and document images. The natural image in (a), the document in (b), the diagram in (c), and the table in (d) demonstrate TextMonkey's ability to identify, understand, and locate text information across a variety of scenarios.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySSXO.png" width="700"/>
<p>
<br>
TextMonkey has shown strong feasibility as an agent for smartphone applications. After fine-tuning on 15k user-click samples from the Rico dataset, TextMonkey is able to understand user intent and click the corresponding icon.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7ySOV6.png" width="700"/>
<p>
<br>
## Citing TextMonkey
If you wish to refer to the baseline results published here, please use the following BibTeX entries:
```BibTeX
@article{liu2024textmonkey,
title={TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document},
author={Liu, Yuliang and Yang, Biao and Liu, Qiang and Li, Zhang and Ma, Zhiyin and Zhang, Shuo and Bai, Xiang},
journal={arXiv preprint arXiv:2403.04473},
year={2024}
}
```
## Copyright
We welcome suggestions to help us improve TextMonkey. For any queries, please contact Dr. Yuliang Liu: ylliu@hust.edu.cn. If you find something interesting, please feel free to share it with us via email or by opening an issue.
import math
from typing import Callable, Tuple
import torch
def self_soft_matching(
    metric: torch.Tensor,
    r: int,
):
    t = metric.shape[1]
    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., :, :], metric[..., :, :]
        scores = a @ b.transpose(-1, -2)  # a_l x b_l pairwise cosine similarities
        b, _, _ = scores.shape
        scores_diag = torch.tril(torch.ones(t, t)) * 2
        scores_diag = scores_diag.expand(b, -1, -1).to(metric.device)
        scores = scores - scores_diag

        node_max, node_idx = scores.max(dim=-1)  # for each token, its highest similarity to another token
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]  # sort tokens by that similarity, descending

        unm_idx = edge_idx[..., t - r:, :]  # Unmerged tokens: the last r (least similar) tokens are kept

    def merge(src: torch.Tensor) -> torch.Tensor:
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, r, c))
        unm_idx_new = unm_idx
        all_idx = unm_idx_new
        all_max, all_idx_idx = torch.sort(all_idx, dim=1)
        return unm.gather(dim=-2, index=all_idx_idx.expand(n, r, c))

    return merge
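
# Usage sketch (hypothetical helper): self_soft_matching builds a merge() closure that keeps
# the r tokens with the lowest maximum similarity score, i.e. roughly the least redundant ones.
def _self_soft_matching_sketch():
    feats = torch.randn(2, 100, 64)          # (batch, tokens, channels)
    merge = self_soft_matching(feats, r=32)
    kept = merge(feats)
    return kept.shape                        # torch.Size([2, 32, 64])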
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum
import torch.nn as nn
import torch
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
# Resampler model
import math

from monkey_model.text_monkey.merge import *
class FeedForward(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.scale = nn.Parameter(torch.ones(1))
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
            nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.scale * x
        return x
class Block(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc_1 = nn.Linear(input_size, output_size)
        self.norm = nn.LayerNorm(output_size)

    def forward(self, x):
        x = self.fc_1(x)
        x = self.norm(x)
        return x
class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        return self.to_out(out)
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        in_dim=1024,
        out_dim=4096,
        depth=1,
        dim_head=128,
        heads=8,
        visual_tokens_num=512,
        ff_mult=4,
    ):
        super().__init__()
        self.downsample = nn.Linear(out_dim, in_dim, bias=False)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=in_dim, dim_head=dim_head, heads=heads),
                        FeedForward(in_features=in_dim, hidden_features=in_dim, out_features=out_dim),
                    ]
                )
            )

    def forward(self, x, r=0):
        B, L, C = x.shape
        merge = self_soft_matching(x, r)  # Replace with your features and r value
        latents = merge(x)
        down_x = self.downsample(x)
        down_latent = self.downsample(latents)
        for attn, ff in self.layers:
            down_latent = attn(down_x, down_latent)
            latents = ff(down_latent) + latents
        return latents
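
# Usage sketch (hypothetical helper; shapes are illustrative): the resampler first selects r
# tokens with self_soft_matching, then refines them with Perceiver-style cross-attention over
# the full, down-projected token sequence.
def _perceiver_resampler_sketch():
    resampler = PerceiverResampler(in_dim=1024, out_dim=4096, depth=1, dim_head=128, heads=8)
    feats = torch.randn(2, 1280, 4096)       # (batch, visual tokens, out_dim) features
    kept = resampler(feats, r=512)           # keep 512 tokens
    return kept.shape                        # torch.Size([2, 512, 4096])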
...@@ -41,7 +41,7 @@ SPECIAL_TOKENS = (
    IMSTART,
    IMEND,
) + EXTRAS
IMG_TOKEN_SPAN = 1280 # IMG_TOKEN_SPAN = 1024
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
...@@ -128,6 +128,7 @@ class QWenTokenizer(PreTrainedTokenizer):
            image_start_tag, image_end_tag,
            image_pad_tag
        )
        self.IMG_TOKEN_SPAN = 1280

        self.errors = errors  # how to handle errors in decoding
...@@ -272,10 +273,10 @@ class QWenTokenizer(PreTrainedTokenizer):
            img_tokens = img_tokens[1:-1]
            img_url = b''.join(img_tokens)
            out_img_tokens = list(map(self.decoder.get, img_url))
            if len(out_img_tokens) > self.IMG_TOKEN_SPAN:
                raise ValueError("The content in {}..{} is too long".format(
                    self.image_start_tag, self.image_end_tag))
            out_img_tokens.extend([self.image_pad_tag] * (self.IMG_TOKEN_SPAN - len(out_img_tokens)))
            out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
            return out_img_tokens
...
...@@ -436,7 +436,10 @@ class VisionTransformer(nn.Module):
        **kwargs
    ):
        super().__init__()
        if isinstance(image_size, (tuple, list)):
            image_height, image_width = self.image_size = image_size
        else:
            image_height, image_width = self.image_size = (image_size, image_size)
        patch_height, patch_width = self.patch_size = (patch_size, patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.output_dim = output_dim
...