Unverified commit b63ace06 authored by Melos, committed by GitHub

TextMonkey (#75)



* textmonkey

* textmonkey code

* Delete README_cn.md

---------
Co-authored-by: Yuliang Liu <34134635+Yuliang-Liu@users.noreply.github.com>
parent ab58e6f0
<p align="left">
English&nbsp | &nbsp<a href="README_cn.md">中文</a>&nbsp
</p>
<br><br>
# Monkey: Image Resolution and Text Label Are Important Things for Large Multi-modal Models
<br>
<p align="center">
<img src="images/Logo-Monkey2.gif" width="300"/>
<img src="https://v1.ax1x.com/2024/04/13/7ySieU.png" width="500" style="margin-bottom: 0.2;"/>
<p>
<br>
<div align="center">
Zhang Li*, Biao Yang*, Qiang Liu, Zhiyin Ma, Shuo Zhang, Jingxu Yang, Yabo Sun, Yuliang Liu†, Xiang Bai†
</div>
<div align="center">
<strong>Huazhong University of Science and Technology, Kingsoft</strong>
</div>
<p align="center">
<a href="https://arxiv.org/abs/2311.06607">Paper</a>&nbsp&nbsp | &nbsp&nbsp<a href="http://vlrlab-monkey.xyz:7681/">Demo_chat</a>&nbsp&nbsp | &nbsp&nbsp<a href="http://huggingface.co/datasets/echo840/Detailed_Caption">Detailed Caption</a>&nbsp&nbsp | &nbsp&nbsp<a href="http://huggingface.co/echo840/Monkey">Model Weight</a>&nbsp&nbsp | <a href="https://www.wisemodel.cn/models/HUST-VLRLab/Monkey/">Model Weight in wisemodel</a>&nbsp&nbsp| <a href="https://wisemodel.cn/space/gradio/huakeMonkey">Demo in wisemodel</a>&nbsp&nbsp
<!-- | &nbsp&nbsp<a href="Monkey Model">Monkey Models</a>&nbsp | &nbsp <a href="http://huggingface.co/echo840/Monkey">Tutorial</a> -->
</p>
<h3 align="center"> <a href="https://arxiv.org/abs/2311.06607">Monkey: Image Resolution and Text Label Are Important Things for Large Multi-modal Models</a></h3>
<h2></h2>
<h5 align="center"> Please give us a star ⭐ for the latest update. </h5>
<h5 align="center">
-----
[![arXiv](https://img.shields.io/badge/Arxiv-2311.06607-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.06607)
[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/Yuliang-Liu/Monkey/blob/main/LICENSE)
[![GitHub issues](https://img.shields.io/github/issues/Yuliang-Liu/Monkey?color=critical&label=Issues)](https://github.com/Yuliang-Liu/Monkey/issues?q=is%3Aopen+is%3Aissue)
[![GitHub closed issues](https://img.shields.io/github/issues-closed/Yuliang-Liu/Monkey?color=success&label=Issues)](https://github.com/Yuliang-Liu/Monkey/issues?q=is%3Aissue+is%3Aclosed) <br>
</h5>
<details open><summary>💡 Monkey series projects:✨. </summary><p>
<!-- may -->
>[CVPR'24] [**Monkey: Image Resolution and Text Label Are Important Things for Large Multi-modal Models**](https://arxiv.org/abs/2311.06607)<br>
> Zhang Li, Biao Yang, Qiang Liu, Zhiyin Ma, Shuo Zhang, Jingxu Yang, Yabo Sun, Yuliang Liu, Xiang Bai <br>
[![Paper](https://img.shields.io/badge/Paper-CVPR'24_Highlight-red)](README.md)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7681/)
[![Detailed Caption](https://img.shields.io/badge/Detailed_Caption-yellow)](http://huggingface.co/datasets/echo840/Detailed_Caption)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](http://huggingface.co/echo840/Monkey)
[![Model Weight in Wisemodel](https://img.shields.io/badge/Model_Weight_in_Wisemodel-gray)](https://www.wisemodel.cn/models/HUST-VLRLab/Monkey/)
[![Demo in Wisemodel](https://img.shields.io/badge/Demo_in_Wisemodel-blue)](https://wisemodel.cn/space/gradio/huakeMonkey)
> [**TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document**](https://arxiv.org/abs/2403.04473)<br>
> Yuliang Liu, Biao Yang, Qiang Liu, Zhang Li, Zhiyin Ma, Shuo Zhang, Xiang Bai <br>
[![arXiv](https://img.shields.io/badge/Arxiv-2403.04473-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.04473)
[![Source_code](https://img.shields.io/badge/Code-Available-white)](monkey_model/text_monkey/README.md)
[![Demo](https://img.shields.io/badge/Demo-blue)](http://vlrlab-monkey.xyz:7684/)
[![Data](https://img.shields.io/badge/Data-yellow)](https://www.modelscope.cn/datasets/lvskiller/TextMonkey_data)
[![Model Weight](https://img.shields.io/badge/Model_Weight-gray)](https://www.modelscope.cn/models/lvskiller/TextMonkey)
**Monkey** brings a training-efficient approach to effectively improve the input resolution capacity up to 896 x 1344 pixels without pretraining from the start. To bridge the gap between simple text labels and high input resolution, we propose a multi-level description generation method, which automatically provides rich information that can guide the model to learn the contextual association between scenes and objects. With the synergy of these two designs, our model achieved excellent results on multiple benchmarks. By comparing our model with various LMMs, including GPT4V, our model demonstrates promising performance in image captioning by paying attention to textual information and capturing fine details within the images; its improved input resolution also enables remarkable performance in document images with dense text.
## News
* ```2024.4.13 ``` 🚀 Sourced code for [TextMonkey](monkey_model/text_monkey/README.md) is released.
* ```2024.4.5 ``` 🚀 Monkey is selected as a CVPR 2024 Highlight paper.
* ```2024.3.8 ``` 🚀 We introduce [TextMonkey](https://arxiv.org/abs/2403.04473), using only public document data, with a ([Demo](http://vlrlab-monkey.xyz:7684/)) available and code forthcoming.
* ```2024.2.27 ``` 🚀 Monkey is accepted by CVPR 2024. The [paper](https://arxiv.org/abs/2311.06607) has been carefully updated according to the valuable comments.
* ```2024.1.3 ``` 🚀 Release the basic data generation pipeline. [Data Generation](./data_generation)
* ```2023.12.21``` 🚀 The JSON file used for Monkey training is provided.
* ```2023.12.16``` 🚀 Monkey can be trained using 8 NVIDIA 3090 GPUs. See subsection [train](#Train) for details.
* ```2023.11.25``` 🚀 Monkey-chat demo is released.
* ```2023.11.06``` 🚀 Monkey [paper](https://arxiv.org/abs/2311.06607) is released.
## Spotlights
- **Contextual associations.** We introduce a multilevel description generation method that improves the model’s ability to grasp the relationships among multiple targets and more effectively utilize common knowledge in generating text descriptions.
- **Support resolution up to 1344 x 896.** Surpassing the standard 448 x 448 resolution typically employed for LMMs, this significant increase in resolution augments the ability to discern and understand unnoticeable or tightly clustered objects and dense text.
- **Enhanced general performance.** We carried out testing across 18 diverse datasets, and the Monkey model delivers very competitive performance in tasks such as Image Captioning, General Visual Question Answering, Scene Text-centric Visual Question Answering, and Document-oriented Visual Question Answering. In particular, in qualitative evaluations centered on dense-text question answering, Monkey has shown promising results compared with GPT4V.
## 🐳 Model Zoo
Monkey-Chat
| Model|Language Model|Transformers(HF) |MMBench-Test|CCBench|MME|SeedBench_IMG|MathVista-MiniTest|HallusionBench-Avg|AI2D Test|OCRBench|
|---------------|---------|-----------------------------------------|---|---|---|---|---|---|---|---|
|Monkey-Chat|Qwen-7B|[🤗echo840/Monkey-Chat](https://huggingface.co/echo840/Monkey-Chat)|72.4|48|1887.4|68.9|34.8|39.3|68.5|534|
## Environment
......@@ -54,25 +69,31 @@ pip install -r requirements.txt
```
## Train
We also offer Monkey's model definition and training code, which you can explore above. You can start training by executing `finetune_ds_debug.sh`.
The json file used for Monkey training can be downloaded at [Link](https://drive.google.com/file/d/18z_uQTe8Jq61V5rgHtxOt85uKBodbvw1/view?usp=sharing).
**ATTENTION:** Specify the path to your training data, which should be a json file consisting of a list of conversations.
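For reference, one entry in such a file might look like the sketch below (the image path and text are placeholders; the `conversations` layout mirrors the format documented in the fine-tuning code in this repository):
```python
# A hypothetical training sample; the actual file is a JSON list of such records.
[
    {
        "conversations": [
            {"from": "user", "value": "<img>path/to/image.jpg</img> What is written on the sign?"},
            {"from": "assistant", "value": "Welcome to Monkey."}
        ]
    }
]
```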
Inspired by Qwen-VL, we freeze the Large Language Model (LLM) and introduce LoRA into four linear layers ```"c_attn", "attn.c_proj", "w1", "w2"``` for training. This step makes it possible to train Monkey using 8 NVIDIA 3090 GPUs. The specific implementation code is in ```modeling_qwen_nvdia3090.py```.
- Add LoRA: You need to replace the contents of ```modeling_qwen.py``` with the contents of ```modeling_qwen_nvdia3090.py```.
- Freeze LLM: You need to freeze all modules except the LoRA and Resampler modules in ```finetune_multitask.py```. A minimal LoRA sketch is shown after this list.
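The following is a minimal sketch of that idea using the `peft` package; it is not the repository's exact implementation (which patches ```modeling_qwen.py```), and the rank/alpha values simply mirror the LoRA defaults in the TextMonkey fine-tuning script included here and are otherwise illustrative:
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the released checkpoint (see the Inference section below).
model = AutoModelForCausalLM.from_pretrained(
    "echo840/Monkey", device_map="cuda", trust_remote_code=True
)

# Attach LoRA adapters to the four linear layers named above; the base model
# parameters stay frozen, which is what makes training on 8 x 3090 feasible.
lora_config = LoraConfig(
    r=16,                # illustrative rank
    lora_alpha=32,       # illustrative scaling
    lora_dropout=0.05,
    target_modules=["c_attn", "attn.c_proj", "w1", "w2"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA parameters are trainable in this sketch
```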
## Inference
Run the inference code:
```
python ./inference.py --model_path MODEL_PATH --image_path IMAGE_PATH --question YOUR_QUESTION
```
## Demo
The demo is fast and easy to use: simply upload an image from your desktop or phone, or capture one directly.
[Demo_chat](http://vlrlab-monkey.xyz:7681) is also launched as an upgraded version of the original demo to deliver an enhanced interactive experience.
We also provide the source code and the model weight for the original demo, allowing you to customize certain parameters for a more unique experience. The specific operations are as follows:
1. Make sure you have configured the [environment](#environment).
2. You can choose to use the demo offline or online:
......@@ -89,6 +110,21 @@ We also provide the source code and the model weight for the original demo, allo
python demo.py -c echo840/Monkey
```
As of 14/11/2023, we have observed that Monkey can achieve more accurate results than GPT4V on some randomly chosen pictures.
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7yS2yq.jpg" width="666"/>
<p>
<br>
As of 31/1/2024, Monkey-Chat ranked fifth in the Multimodal Model category on [OpenCompass](https://opencompass.org.cn/home).
<br>
<p align="center">
<img src="https://v1.ax1x.com/2024/04/13/7yShXL.jpg" width="666"/>
<p>
<br>
## Dataset
The json file used for Monkey training can be downloaded at [Link](https://drive.google.com/file/d/18z_uQTe8Jq61V5rgHtxOt85uKBodbvw1/view?usp=sharing).
......@@ -97,7 +133,7 @@ The data from our multi-level description generation method is now open-sourced
<br>
<p align="center">
<img src="images/detailed_caption.png" width="1000"/>
<img src="https://v1.ax1x.com/2024/04/13/7yS6Ss.jpg" width="666"/>
<p>
<br>
......@@ -148,122 +184,28 @@ ds_collections = {
bash eval/eval.sh 'EVAL_PTH' 'SAVE_NAME'
```
## Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "echo840/Monkey"
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
img_path = ""
question = ""
query = f'<img>{img_path}</img> {question} Answer: ' #VQA
# query = f'<img>{img_path}</img> Generate the detailed caption in English: ' #detailed caption
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
    input_ids=input_ids.cuda(),
    attention_mask=attention_mask.cuda(),
    do_sample=False,
    num_beams=1,
    max_new_tokens=512,
    min_new_tokens=1,
    length_penalty=1,
    num_return_sequences=1,
    output_hidden_states=True,
    use_cache=True,
    pad_token_id=tokenizer.eod_id,
    eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
print(response)
```
## Performance
<br>
<p align="center">
<img src="images/radar_1.png" width="800"/>
<p>
<br>
## Cases
Our model can accurately describe the details in the image.
<br>
<p align="center">
<img src="images/caption_1.png" width="700"/>
<p>
<br>
Our model performs particularly well on dense-text question answering tasks. For example, given the dense text on an item label, Monkey can accurately answer questions about various details of the item, and its performance is very impressive compared to other LMMs, including GPT4V.
<br>
<p align="center">
<img src="images/dense_text_1.png" width="700"/>
<p>
<br>
<br>
<p align="center">
<img src="images/dense_text_2.png" width="700"/>
<p>
<br>
Monkey also performs well in everyday scenes. It can handle various Q&A and captioning tasks and describe fine details in an image, even an inconspicuous watermark.
<br>
<p align="center">
<img src="images/qa_caption.png" width="700"/>
<p>
<br>
We qualitatively compare Monkey with existing LMMs, including GPT4V and Qwen-VL, with encouraging results. You can try it yourself using the provided demo.
<br>
<p align="center">
<img src="images/compare.png" width="800"/>
<p>
<br>
## Citing Monkey
If you wish to refer to the baseline results published here, please use the following BibTeX entries:
```BibTeX
@inproceedings{li2023monkey,
  title={Monkey: Image Resolution and Text Label Are Important Things for Large Multi-modal Models},
  author={Li, Zhang and Yang, Biao and Liu, Qiang and Ma, Zhiyin and Zhang, Shuo and Yang, Jingxu and Sun, Yabo and Liu, Yuliang and Bai, Xiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}
@article{liu2024textmonkey,
title={TextMonkey: An OCR-Free Large Multimodal Model for Understanding Document},
author={Liu, Yuliang and Yang, Biao and Liu, Qiang and Li, Zhang and Ma, Zhiyin and Zhang, Shuo and Bai, Xiang},
journal={arXiv preprint arXiv:2403.04473},
year={2024}
}
```
## Acknowledgement
[Qwen-VL](https://github.com/QwenLM/Qwen-VL.git): the codebase we built upon. Thanks to the authors of Qwen for providing the framework. We also thank [LLAMA](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [OpenCompass](https://github.com/open-compass/opencompass), and [InternLM](https://github.com/InternLM/InternLM).
## Copyright
......
<p align="left">
Chinese&nbsp | &nbsp<a href="README.md">English</a>&nbsp
</p>
<br><br>
# Monkey: Image Resolution and Text Label Are Important Things for Large Multi-modal Models
<br>
<p align="center">
<img src="images/logo_monkey.png" width="300"/>
<p>
<br>
<div align="center">
Zhang Li*, Biao Yang*, Qiang Liu, Zhiyin Ma, Shuo Zhang, Jingxu Yang, Yabo Sun, Yuliang Liu†, Xiang Bai†
</div>
<div align="center">
<strong>Huazhong University of Science and Technology, Kingsoft</strong>
</div>
<p align="center">
<a href="https://arxiv.org/abs/2311.06607">论文</a>&nbsp&nbsp | &nbsp&nbsp<a href="http://vlrlab-monkey.xyz:7681/">对话演示</a>&nbsp&nbsp | &nbsp&nbsp<a href="http://huggingface.co/datasets/echo840/Detailed_Caption">详细描述</a>&nbsp&nbsp | &nbsp&nbsp<a href="http://huggingface.co/echo840/Monkey">模型权重</a>&nbsp&nbsp | <a href="https://www.wisemodel.cn/models/HUST-VLRLab/Monkey/">始智AI</a>&nbsp&nbsp
<!-- | &nbsp&nbsp<a href="Monkey Model">Monkey Models</a>&nbsp | &nbsp <a href="http://huggingface.co/echo840/Monkey">Tutorial</a> -->
</p>
-----
**Monkey** introduces an efficient training method that effectively raises the input resolution to 896 x 1344 without pretraining from scratch. To bridge the gap between simple text labels and high input resolution, Monkey also proposes a multi-level description generation method, which automatically provides rich information that can guide the model to learn the association between scenes and objects. With the synergy of these two designs, Monkey achieves excellent results on multiple benchmarks. Compared with various large multi-modal models, including GPT4V, Monkey shows promising performance in image captioning by paying attention to textual information and capturing fine details within images; its high input resolution also enables remarkable performance on document images with dense text.
## News
* ```2023.12.21``` 🚀🚀🚀 The JSON file used for Monkey training is released.
* ```2023.12.16``` 🚀🚀🚀 Monkey can be trained using 8 NVIDIA 3090 GPUs. See [Train](#train) for details.
* ```2023.11.25``` 🚀🚀🚀 The Monkey [chat demo](http://vlrlab-monkey.xyz:7681/) is released.
* ```2023.11.06``` 🚀🚀🚀 The Monkey [paper](https://arxiv.org/abs/2311.06607) is released.
## Contributions
- **Contextual associations.** When answering questions, Monkey demonstrates a superior ability to infer the relationships between targets, enabling more comprehensive and insightful results.
- **Support for resolutions up to 1344 x 896.** This significantly exceeds the standard 448 x 448 resolution typically used by LMMs, enhancing the ability to discern and understand inconspicuous or tightly clustered objects and dense text.
- **Improved performance.** Testing on 16 diverse datasets shows that Monkey performs well in tasks such as image captioning, general visual question answering, scene text-centric visual question answering, and document-oriented visual question answering.
## Environment
```python
conda create -n monkey python=3.9
conda activate monkey
git clone https://github.com/Yuliang-Liu/Monkey.git
cd ./Monkey
pip install -r requirements.txt
```
## Demo
The demo is fast and easy to use: simply upload an image from your desktop or phone, or capture one directly.
To provide a better interactive experience, we have also launched [Demo_chat](http://27.18.93.119:7681/), an upgraded version of the original demo.
We have observed that, for some random pictures, Monkey can achieve more accurate results than GPT4V.
<br>
<p align="center">
<img src="images/demo_gpt4v_compare4.png" width="900"/>
<p>
<br>
We also provide the source code and model weights for the original demo, allowing you to customize certain parameters for a more unique experience. The steps are as follows:
1. Make sure you have configured the [environment](#environment).
2. You can run `demo.py` either offline or online:
- **Offline:**
- Download the [model weights](http://huggingface.co/echo840/Monkey).
- Change `DEFAULT_CKPT_PATH="pathto/Monkey"` in `demo.py` to the path of the downloaded model weights.
- Run the demo with the following command:
```
python demo.py
```
- **Online:**
- Load the model and run the demo with the following command:
```
python demo.py -c echo840/Monkey
```
## Dataset
The JSON file used for Monkey training can be downloaded at [Link](https://drive.google.com/file/d/18z_uQTe8Jq61V5rgHtxOt85uKBodbvw1/view?usp=sharing).
We have open-sourced the data generated by the multi-level description generation method. You can download it here: [Detailed Caption](https://huggingface.co/datasets/echo840/Detailed_Caption).
## Evaluate
We provide evaluation code for 14 Visual Question Answering (VQA) datasets in `evaluate_vqa.py` for quick verification of results. The steps are as follows:
1. Make sure you have configured the [environment](#environment).
2. Change `sys.path.append("pathto/Monkey")` to the path of this project.
3. Prepare the datasets to be evaluated.
4. Run the evaluation code.
Take the evaluation on the ESTVQA dataset as an example:
- Prepare the dataset in the following format:
```
├── data
| ├── estvqa
| ├── test_image
| ├── {image_path0}
| ├── {image_path1}
| ·
| ·
| ├── estvqa.jsonl
```
- Example format of each line in the `.jsonl` annotation file:
```
{"image": "data/estvqa/test_image/011364.jpg", "question": "What is this store?", "answer": "pizzeria", "question_id": 0}
```
- Modify the dictionary `ds_collections`:
```
ds_collections = {
'estvqa_test': {
'test': 'data/estvqa/estvqa.jsonl',
'metric': 'anls',
'max_new_tokens': 100,
},
...
}
```
- Run the following command:
```
bash eval/eval.sh 'EVAL_PTH' 'SAVE_NAME'
```
## Train
We also provide Monkey's model definition and training code, which you can explore above. You can start training by executing `finetune_ds_debug.sh`.
The JSON file used for Monkey training can be downloaded at [Link](https://drive.google.com/file/d/18z_uQTe8Jq61V5rgHtxOt85uKBodbvw1/view?usp=sharing).
**ATTENTION:** Specify the path to your training data, which should be a JSON file consisting of a list of conversations.
Inspired by Qwen-VL, we freeze the Large Language Model (LLM) and introduce LoRA into four linear layers ```"c_attn", "attn.c_proj", "w1", "w2"``` for training. This makes it possible to train Monkey using 8 NVIDIA 3090 GPUs.
- Add LoRA: Replace the contents of ```modeling_qwen.py``` with the contents of ```modeling_qwen_nvdia3090.py```.
- Freeze LLM: Freeze all modules except the LoRA and Resampler modules in ```finetune_multitask.py```.
## Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "echo840/Monkey"
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='cuda', trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
img_path = ""
question = ""
query = f'<img>{img_path}</img> {question} Answer: ' #VQA
# query = f'<img>{img_path}</img> Generate the detailed caption in English: ' #detailed caption
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
    input_ids=input_ids.cuda(),
    attention_mask=attention_mask.cuda(),
    do_sample=False,
    num_beams=1,
    max_new_tokens=512,
    min_new_tokens=1,
    length_penalty=1,
    num_return_sequences=1,
    output_hidden_states=True,
    use_cache=True,
    pad_token_id=tokenizer.eod_id,
    eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
print(response)
```
## Performance
<br>
<p align="center">
<img src="images/radar_1.png" width="800"/>
<p>
<br>
## Cases
Monkey can accurately describe the details in an image.
<br>
<p align="center">
<img src="images/caption_1.png" width="700"/>
<p>
<br>
Monkey performs particularly well on dense-text question answering tasks. For example, given the dense text on an item label, Monkey can accurately answer questions about various details of the item, and its performance is outstanding compared with other LMMs, including GPT4V.
<br>
<p align="center">
<img src="images/dense_text_1.png" width="700"/>
<p>
<br>
<br>
<p align="center">
<img src="images/dense_text_2.png" width="700"/>
<p>
<br>
Monkey also performs well in everyday scenes. It can handle various Q&A and captioning tasks and describe fine details in an image, even an inconspicuous watermark.
<br>
<p align="center">
<img src="images/qa_caption.png" width="700"/>
<p>
<br>
Qualitative comparisons with existing LMMs, including GPT4V and Qwen-VL, show encouraging results for Monkey. You can try it yourself using the provided demo.
<br>
<p align="center">
<img src="images/compare.png" width="800"/>
<p>
<br>
## Citing Monkey
If you find our paper and code helpful for your research, please consider giving us a star and citing our work:
```BibTeX
@article{li2023monkey,
title={Monkey: Image Resolution and Text Label Are Important Things for Large Multi-modal Models},
author={Li, Zhang and Yang, Biao and Liu, Qiang and Ma, Zhiyin and Zhang, Shuo and Yang, Jingxu and Sun, Yabo and Liu, Yuliang and Bai, Xiang},
journal={arXiv preprint arXiv:2311.06607},
year={2023}
}
```
## Acknowledgement
Our code is built upon [Qwen-VL](https://github.com/QwenLM/Qwen-VL.git). Thanks to the authors of Qwen for providing the framework.
## Copyright
We welcome suggestions to help us improve Monkey. For any questions, please contact Dr. Yuliang Liu: ylliu@hust.edu.cn. If you find something interesting, please feel free to share it with us via email or open an issue. Thank you!
EVAL_PTH=$1
SAVE_NAME=$2
python -m torch.distributed.launch --use-env --nproc_per_node ${NPROC_PER_NODE:-8} --nnodes ${WORLD_SIZE:-1} --node_rank ${RANK:-0} --master_addr ${MASTER_ADDR:-127.0.0.1} --master_port ${MASTER_PORT:-12345} eval/evaluate_vqa_doc.py --checkpoint $EVAL_PTH --batch-size 8 --num-workers 4 --save_name $SAVE_NAME
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
from typing import Optional
import sys
import torch
from tqdm import tqdm
from vqa import VQA
from vqa_eval import VQAEval
sys.path.append("pathto/Monkey/")
from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
import numpy as np
from pathlib import Path
import re
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
from monkey_model.configuration_qwen import QWenConfig
from monkey_model.configuration_monkey import MonkeyConfig
ds_collections = {
'docvqa_test': {
'train': 'data/docvqa/train.jsonl',
'test': 'data/docvqa/test_ans.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'ocrvqa_test': {
'train': 'data/ocrvqa/ocrvqa_test.jsonl',
'test': 'data/ocrvqa/ocrvqa_test.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'chartqa_ureader': {
'train': 'data/chartqa/train_augmented.jsonl',
'test': 'data/chartqa/chartqa_ureader.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'FUNSD': {
'train': 'data/chartqa/train_augmented.jsonl',
'test': 'data/FUNSD/FUNSD_test.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'SROIE_test': {
'train': 'data/chartqa/train_augmented.jsonl',
'test': 'data/SROIE/SROIE_test.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'POIE': {
'train': 'data/chartqa/train_augmented.jsonl',
'test': 'data/POIE/POIE_test.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'textvqa_val': {
'train': 'data/textvqa/textvqa_train.jsonl',
'test': 'data/textvqa/textvqa_val.jsonl',
'question': 'data/textvqa/textvqa_val_questions.json',
'annotation': 'data/textvqa/textvqa_val_annotations.json',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'infovqa_test': {
'train': 'data/infographicVQA/infovqa.jsonl',
'test': 'data/infographicVQA/infovqa_test.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
'stvqa_test': {
'train': 'data/STVQA/stvqa.jsonl',
'test': 'data/STVQA/stvqa.jsonl',
'metric': 'accuracy',
'max_new_tokens': 100,
},
}
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2+1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
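# Example: levenshtein_distance("kitten", "sitting") == 3
# (substitute k->s, substitute e->i, insert g).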
def normANLS(s1,s2):
dist = levenshtein_distance(s1.lower().strip(),s2.lower().strip())
length = max(len(s1),len(s2))
value = 0.0 if length == 0 else float(dist) / float(length)
return value
def evaluateANLS(ans_list):
anls_threshold = 0.5
anls_list = []
for predict_pair in ans_list:
answer = predict_pair["answer"].strip()
gt_list = predict_pair["annotation"]
value_list = []
for gt_single in gt_list:
if gt_single.strip().lower() in answer.strip().lower():
value_list.append(0)
value_list.append(normANLS(gt_single,answer))
question_result = 1 - min(value_list)
if (question_result < anls_threshold) :
question_result = 0
anls_list.append(question_result)
return np.mean(anls_list)
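# Each question scores 1 - min(normalized edit distance) over its ground truths,
# and scores below the 0.5 threshold are zeroed, following the standard ANLS metric.
# For example, answer "pizzeria" against annotation ["Pizzeria"] scores 1.0.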
# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
prediction: str,
max_relative_change: float = 0.05) -> bool:
"""Calculates relaxed correctness.
The correctness tolerates certain error ratio defined by max_relative_change.
See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
“Following Methani et al. (2020), we use a relaxed accuracy measure for the
numeric answers to allow a minor inaccuracy that may result from the automatic
data extraction process. We consider an answer to be correct if it is within
5% of the gold answer. For non-numeric answers, we still need an exact match
to consider an answer to be correct.”
Args:
target: Target string.
prediction: Predicted string.
max_relative_change: Maximum relative change.
Returns:
Whether the prediction was correct given the specified tolerance.
"""
def _to_float(text: str) -> Optional[float]:
try:
if text.endswith('%'):
# Convert percentages to floats.
return float(text.rstrip('%')) / 100.0
else:
return float(text)
except ValueError:
return None
prediction_float = _to_float(prediction)
target_float = _to_float(target)
if prediction_float is not None and target_float:
relative_change = abs(prediction_float -
target_float) / abs(target_float)
return relative_change <= max_relative_change
else:
return prediction.lower() == target.lower()
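# For example, relaxed_correctness("100", "96") is True since |96 - 100| / 100 = 0.04 <= 0.05,
# while non-numeric answers such as ("cat", "dog") require an exact (case-insensitive) match.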
def evaluate_relaxed_accuracy(entries):
scores = []
for elem in entries:
if isinstance(elem['annotation'], str):
elem['annotation'] = [elem['annotation']]
score = max([
relaxed_correctness(elem['answer'].strip(), ann)
for ann in elem['annotation']
])
scores.append(score)
return sum(scores) / len(scores)
def evaluate_exact_match_accuracy(entries):
scores = []
for elem in entries:
if isinstance(elem['annotation'], str):
elem['annotation'] = [elem['annotation']]
quad_blocks = re.findall(r'<point>(.*?)</point>', elem['answer'])
for quad_block in quad_blocks:
elem['answer'] = elem['answer'].replace('<point>' + quad_block + '</point>', '')
quad_blocks = re.findall(r'<box>(.*?)</box>', elem['answer'])
for quad_block in quad_blocks:
elem['answer'] = elem['answer'].replace('<box>' + quad_block + '</box>', '')
score = max([
(1.0 if
(ann.strip().lower() in elem['answer'].strip().lower() ) else 0.0)
for ann in elem['annotation']
])
scores.append(score)
return sum(scores) / len(scores)
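# Grounding tags are stripped before matching, so an answer like
# "<point>(0.1,0.2)</point> pizzeria" against annotation ["pizzeria"] scores 1.0
# via the case-insensitive containment check.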
def collate_fn(batches, tokenizer):
image_paths = [_['image_path'] for _ in batches]
questions = [_['question'] for _ in batches]
question_ids = [_['question_id'] for _ in batches]
annotations = [_['annotation'] for _ in batches]
input_ids = tokenizer(questions, return_tensors='pt', padding='longest')
return image_paths,question_ids, input_ids.input_ids, input_ids.attention_mask, annotations
class VQADataset(torch.utils.data.Dataset):
def __init__(self, train, test, prompt, few_shot):
self.test = open(test).readlines()
self.prompt = prompt
self.few_shot = few_shot
if few_shot > 0:
self.train = open(train).readlines()
def __len__(self):
return len(self.test)
def __getitem__(self, idx):
data = json.loads(self.test[idx].strip())
image, question, question_id, annotation = data['image'], data[
'question'], data['question_id'], data.get('answer', None)
few_shot_prompt = ''
if self.few_shot > 0:
few_shot_samples = random.sample(self.train, self.few_shot)
for sample in few_shot_samples:
sample = json.loads(sample.strip())
few_shot_prompt += self.prompt.format(
sample['image'],
sample['question']) + f" {sample['answer']}"
return {
'image_path':image,
'question': few_shot_prompt + self.prompt.format(image, question),
'question_id': question_id,
'annotation': annotation
}
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size,
self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
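# Shard sizes differ by at most one: e.g. with total_size=10 and world_size=4 the
# shards are [3, 3, 2, 2], so rank 0 iterates range(0, 3) and rank 3 iterates range(8, 10).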
def evaluate(model,tokenizer,prompt,args,dataset_name):
dataset_info = ds_collections[dataset_name]
dataset = VQADataset(
train=dataset_info['train'],
test=dataset_info['test'],
prompt=prompt,
few_shot=args.few_shot,
)
len_dataset = len(dataset)
if torch.distributed.get_rank() == 0:
print(f"there have {len(dataset)} in {dataset_name}")
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len_dataset),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, tokenizer=tokenizer),
)
outputs = []
for image_paths,question_ids, input_ids, attention_mask,annotations in tqdm(dataloader):
pred = model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=False,
num_beams=1,
max_new_tokens=dataset_info['max_new_tokens'],
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
answers = [
tokenizer.decode(_[input_ids.size(1):].cpu(),
skip_special_tokens=True).strip() for _ in pred
]
answers = [answer.replace("<|endoftext|>","") for answer in answers]
questions = [
tokenizer.decode(_[:input_ids.size(1)].cpu(),
skip_special_tokens=False).strip() for _ in pred
]
questions = [question.replace("<|endoftext|>","") for question in questions]
print(questions[0],answers[0])
for image_path,question,question_id, answer, annotation in zip(image_paths,questions,question_ids, answers,
annotations):
if dataset_info['metric'] == 'vqa_score':
outputs.append({
'image_path':image_path,
'question_id': question_id,
'answer': answer,
'question':question
})
elif dataset_info['metric'] == 'anls':
if isinstance(annotation,list):
outputs.append({
'image_path':image_path,
'questionId': question_id,
'answer': answer,
'annotation': annotation,
'question':question
})
else:
outputs.append({
'image_path':image_path,
'questionId': question_id,
'answer': answer,
'annotation': [annotation],
'question':question
})
elif dataset_info['metric'] == 'accuracy':
outputs.append({
'image_path':image_path,
'questionId': question_id,
'answer': answer,
'annotation': annotation,
'question':question
})
elif dataset_info['metric'] == 'accuracy_recog':
outputs.append({
'image_path':image_path,
'questionId': question_id,
'answer': answer,
'annotation': annotation,
'question':question
})
elif dataset_name in ["chartqa_ureader"]:
outputs.append({
'image_path':image_path,
'answer': answer,
'annotation': annotation,
'question':question
})
else:
raise NotImplementedError
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_outputs = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
merged_outputs = [json.loads(_) for _ in merged_outputs]
merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
if torch.distributed.get_rank() == 0:
print(f"Evaluating {dataset_name} ...")
results_file = f'{dataset_name}.json'
root_path = os.path.join("result_doc",args.save_name,time_prefix)
Path(root_path).mkdir(exist_ok=True,parents=True)
results_file = os.path.join(root_path,results_file)
json.dump(merged_outputs, open(results_file, 'w',encoding="utf-8"), ensure_ascii=False,indent=2)
if dataset_info['metric'] == 'vqa_score':
vqa = VQA(dataset_info['annotation'],dataset_info['question'])
results = vqa.loadRes(
resFile=results_file,
quesFile=dataset_info['question'])
vqa_scorer = VQAEval(vqa, results, n=2)
question_id_list = [item["question_id"]for item in merged_outputs]
vqa_scorer.evaluate(question_id_list)
print(vqa_scorer.accuracy)
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(vqa_scorer.accuracy["overall"])+'\n')
elif dataset_info['metric'] == 'anls':
anls_res = evaluateANLS(merged_outputs)
print(anls_res)
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(anls_res)+'\n')
elif dataset_info['metric'] == 'relaxed_accuracy':
print({
'relaxed_accuracy': evaluate_relaxed_accuracy(merged_outputs)
})
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(evaluate_relaxed_accuracy(merged_outputs))+'\n')
elif dataset_info['metric'] == 'accuracy':
if 'gqa' in dataset_name:
for entry in merged_outputs:
response = entry['answer']
response = response.strip().split('.')[0].split(
',')[0].split('!')[0].lower()
if 'is ' in response:
response = response.split('is ')[1]
if 'are ' in response:
response = response.split('are ')[1]
if 'a ' in response:
response = response.split('a ')[1]
if 'an ' in response:
response = response.split('an ')[1]
if 'the ' in response:
response = response.split('the ')[1]
if ' of' in response:
response = response.split(' of')[0]
response = response.strip()
entry['answer'] = response
acc = evaluate_exact_match_accuracy(merged_outputs)
print({'accuracy': acc})
results_file = results_file.replace("json","txt")
with open(results_file,"w") as fp:
fp.write(dataset_name+"\n")
fp.writelines(str(acc)+'\n')
torch.distributed.barrier()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--dataset', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--few-shot', type=int, default=0)
parser.add_argument('--seed', type=int, default=3407)
parser.add_argument("--save_name",type=str,default="test")
args = parser.parse_args()
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
config = MonkeyConfig.from_pretrained(
args.checkpoint,
trust_remote_code=True,
)
print(config)
model = TextMonkeyLMHeadModel.from_pretrained(args.checkpoint,
config=config,
device_map='cuda', trust_remote_code=True).eval()
tokenizer = QWenTokenizer.from_pretrained(args.checkpoint,
trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.IMG_TOKEN_SPAN = config.visual["n_queries"]
random.seed(args.seed)
for k,_ in ds_collections.items():
# prompt = '<img>{}</img> {} Provide the location coordinates of the answer when answering the question. Answer:'
# prompt = '<img>{}</img> Convert the document in this image to json format. Answer: '
prompt = '<img>{}</img> {} Answer:'
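# e.g. for a DocVQA sample the formatted query looks like
# "<img>data/docvqa/images/xxxx.png</img> What is the date on the form? Answer:"
# (illustrative path and question).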
evaluate(model,tokenizer,prompt,args,k)
......@@ -10,42 +10,17 @@
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
"wall_clock_breakdown": false,
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="Qwen/Qwen-VL" # We use the first version of Qwen-VL
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="pathto/data"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS finetune_multitask_dialouge_doc.py \
--model_name_or_path $MODEL \
--data_path $DATA \
--bf16 True \
--fix_vit True \
--output_dir output_model \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.02 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--gradient_checkpointing \
--lazy_preprocess True \
--deepspeed finetune/ds_config_zero2.json \
--image_size 896 \
--image_width 896 \
--image_height 896 \
--add_window true \
--use_global true \
--resampler true \
--remain 512
# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
from dataclasses import dataclass, field
import json
import math
import logging
import os
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, GPTQConfig, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
from monkey_model.configuration_monkey import MonkeyConfig
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
eval_data_path: str = field(
default=None, metadata={"help": "Path to the evaluation data."}
)
lazy_preprocess: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=8192,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
use_lora: bool = False
fix_vit: bool = True
fix_llm: bool = False
fix_resampler: bool = False
image_size: int = 448
image_width: int = 896
image_height: int = 896
n_queries: int = 256
lora_repeat_num : int = 0
add_window: bool = False
use_global: bool = True
resampler: bool = False
remain:int = 512
@dataclass
class LoraArguments:
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
lora_target_modules: List[str] = field(
default_factory=lambda: ["in_proj","out_proj","c_fc"] ##["in_proj","out_proj","c_fc"]
)
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False
def maybe_zero_3(param):
if hasattr(param, "ds_id"):
assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if bias_name in lora_bias_names:
to_return[bias_name] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
return to_return
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
# check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled():
state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
else:
state_dict = trainer.model.state_dict()
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer._save(output_dir, state_dict=state_dict)
def format_tokenizer(tokenizer, message, return_target=False, label=False):
_input_ids = tokenizer(message).input_ids
input_ids = _input_ids
if return_target:
if label:
target = input_ids
else:
target = [IGNORE_TOKEN_ID] * (len(_input_ids))
return input_ids, target
else:
return input_ids
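# When return_target is True, assistant turns (label=True) keep their token ids as
# targets, while user turns are masked with IGNORE_TOKEN_ID so the loss is computed
# only on the assistant's replies.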
def preprocess(
source,
tokenizer,
max_len,
system_message: str = "You are a helpful assistant.",
padding=True
):
'''
[{"from": "user", "value": f"<img>{file_abs_path}</img>" + prefix,}, {"from": "assistant", "value": label}]
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<img>image_path<imgpad><imgpad><imgpad><imgpad> ... </img>Describe the image concisely.<|im_end|>
<|im_start|>assistant
A man on a surfboard on a wave in the ocean.<|im_end|>
'''
# Apply prompt templates
input_ids, targets = [], []
message_l = []
for conv in source:
message_l.append(conv["value"])
for i, message in enumerate(message_l):
try:
_input_ids, _target = format_tokenizer(tokenizer, message, return_target=True, label=True if i %2==1 else False) # some text may itself contain <img> tags, so using <img> as a special id is problematic; a mismatched number of tags will raise an error
except Exception as e:
print(e)
continue
input_ids += _input_ids
targets += _target
if i%2==1:
input_ids += [-1]
targets += [tokenizer.pad_token_id]
assert len(_input_ids) == len(_target)
if padding:
input_ids += [tokenizer.pad_token_id] * (max_len - len(input_ids))
targets += [IGNORE_TOKEN_ID] * (max_len - len(targets))
targets = targets[:max_len]
input_ids = input_ids[:max_len]
input_ids = torch.tensor(input_ids, dtype=torch.int)
targets = torch.tensor(targets, dtype=torch.int)
attention_mask=input_ids.ne(tokenizer.pad_token_id)
input_ids[input_ids == -1 ] = tokenizer.pad_token_id
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=attention_mask,
)
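# preprocess() concatenates all turns of one conversation, pads input_ids / labels to
# max_len with pad_token_id / IGNORE_TOKEN_ID, and derives the attention mask from the
# non-pad positions before the -1 separators are replaced by pad_token_id.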
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
super(SupervisedDataset, self).__init__()
rank0_print("Formatting inputs...")
sources = [example["conversations"] for example in raw_data]
data_dict = preprocess(sources, tokenizer, max_len)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.attention_mask = data_dict["attention_mask"]
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(
input_ids=self.input_ids[i],
labels=self.labels[i],
attention_mask=self.attention_mask[i],
)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
self.max_len = max_len
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.raw_data = raw_data
self.cached_data_dict = {}
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
if i in self.cached_data_dict:
return self.cached_data_dict[i]
ret = preprocess(self.raw_data[i]["conversations"], self.tokenizer, self.max_len)
ret = dict(
input_ids=ret["input_ids"],
labels=ret["labels"],
attention_mask=ret["attention_mask"],
)
self.cached_data_dict[i] = ret
return ret
def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer, data_args, max_len,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = (
LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
)
rank0_print("Loading data...")
train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(train_json, tokenizer=tokenizer, max_len=max_len)
if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, max_len=max_len)
else:
eval_dataset = None
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
def print_trainable_params(model: torch.nn.Module):
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
all_param += num_params
if param.requires_grad:
trainable_params += num_params
rank0_print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param))
for name,p in model.named_parameters():
if p.requires_grad and "transformer.h" not in name and "lora" not in name:
if "lora" in name:
if "39" not in name:
continue
rank0_print(name)
# for name,p in model.named_parameters():
# rank0_print(name,p.device)
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
)
(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()
if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
compute_dtype = (
torch.float16
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)
local_rank = training_args.local_rank
device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
logging.warning(
"FSDP or ZeRO3 are not incompatible with QLoRA."
)
# Set RoPE scaling factor
config = MonkeyConfig.from_pretrained(
"monkey_model",
cache_dir=training_args.cache_dir,
trust_remote_code=True,
)
config.visual["image_size"]= (training_args.image_height,training_args.image_width)
config.visual["n_queries"]= training_args.n_queries
config.visual["lora_repeat_num"]= training_args.lora_repeat_num
config.visual["add_window"]= training_args.add_window
config.visual["use_global"]= training_args.use_global
config.visual["resampler"]= training_args.resampler
config.visual["r"]= training_args.remain
rank0_print(config)
config.use_cache = False
# Load model and tokenizer
rank0_print("loading base model")
model = TextMonkeyLMHeadModel.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
device_map=device_map,
trust_remote_code=True,
quantization_config=GPTQConfig(
bits=4, disable_exllama=True
)
if training_args.use_lora and lora_args.q_lora
else None,
ignore_mismatched_sizes=True
)
tokenizer = QWenTokenizer.from_pretrained(
"monkey_model",
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
trust_remote_code=True,
)
tokenizer.IMG_TOKEN_SPAN = training_args.n_queries
if training_args.resampler:
tokenizer.IMG_TOKEN_SPAN =training_args.remain
if training_args.use_global:
tokenizer.IMG_TOKEN_SPAN += training_args.n_queries
tokenizer.pad_token_id = tokenizer.eod_id
rank0_print(tokenizer.IMG_TOKEN_SPAN)
config.visual["n_queries"]= tokenizer.IMG_TOKEN_SPAN
if not training_args.use_lora:
if training_args.fix_vit and hasattr(model,'transformer') and hasattr(model.transformer,'visual'):
model.transformer.visual.requires_grad_(False)
if not training_args.fix_resampler and hasattr(model.transformer.visual,'attn_pool'):
model.transformer.visual.attn_pool.requires_grad_(True)
model.transformer.visual.ln_post.requires_grad_(True)
model.transformer.visual.proj.requires_grad_(True)
if hasattr(model.transformer.visual,'downresampler'):
model.transformer.visual.downresampler.requires_grad_(True)
for k,v in model.named_parameters():
if "lora" in k :
v.requires_grad_(True)
for k,v in model.named_parameters():
if "window_attention" in k :
v.requires_grad_(True)
if training_args.fix_llm and hasattr(model,'transformer') and hasattr(model.transformer,'h'):
model.transformer.h.requires_grad_(False)
model.transformer.wte.requires_grad_(False)
model.transformer.ln_f.requires_grad_(False)
model.lm_head.requires_grad_(False)
if training_args.use_lora:
if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower():
modules_to_save = None
else:
modules_to_save = []
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
task_type="CAUSAL_LM",
modules_to_save=modules_to_save # This argument serves for adding new tokens.
)
model = get_peft_model(model, lora_config)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
# Load data
data_module = make_supervised_data_module(
tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length
)
print_trainable_params(model)
# Start trainer
trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args, **data_module
)
trainer.train()
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
import numpy as np
import random
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(seed)
if __name__ == "__main__":
setup_seed(46)
train()