Commit 18d71758 authored by wangsen's avatar wangsen
Browse files

Update README.md, requirements, requirements.txt, requirements_wo_ds.txt,...

Update README.md, requirements, requirements.txt, requirements_wo_ds.txt, test.py, api.py, cli_demo_mp.py, LICENSE.txt, MODEL_LICENSE.txt, finetune_visualglm.py, cli_demo_hf.py, cli_demo.py, web_demo_hf.py, web_demo.py, README_en.md, api_hf.py, temp.json, fewshot-data.zip files
Deleted .gitignore, .gitlab-ci.yml, LICENSE, gatsby-browser.js, gatsby-node.js, gatsby-ssr.js, gatsby-config.js, package-lock.json, package.json, src/components/header.js, src/components/image.js, src/components/layout.css, src/components/layout.js, src/components/seo.js, src/images/gatsby-astronaut.png, src/images/gatsby-icon.png, src/pages/404.js, src/pages/index.js, src/pages/page-2.js files
parent 60c3df39
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
# nyc test coverage
.nyc_output
# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (http://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Typescript v1 declaration files
typings/
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# dotenv environment variable files
.env*
# gatsby files
.cache/
public
# Mac files
.DS_Store
# Yarn
yarn-error.log
.pnp/
.pnp.js
# Yarn Integrity file
.yarn-integrity
image: node:10
variables:
GIT_SUBMODULE_STRATEGY: recursive
build:
stage: build
script:
- CI=true
- npm install
- npm run build
artifacts:
paths:
- public
cache:
key: ${CI_COMMIT_REF_SLUG}
paths:
- node_modules/
test:
stage: build
script:
- npm install
- npm test
cache:
key: ${CI_COMMIT_REF_SLUG}
paths:
- node_modules/
pages:
image: alpine:latest
stage: deploy
script:
- echo "GitLab Pages magic happens here"
dependencies:
- build # Pulls in the artifacts from the build stage
artifacts:
paths:
- public # Required artifact of Gitlab pages
# only:
# - master
The MIT License (MIT)
Copyright (c) 2015 gatsbyjs
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright Zhengxiao Du
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
The VisualGLM-6B License
1. Definitions
“Licensor” means the VisualGLM-6B Model Team that distributes its Software.
“Software” means the VisualGLM-6B model parameters made available under this license.
2. License Grant
Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
3. Restriction
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
4. Disclaimer
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
5. Limitation of Liability
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
6. Dispute Resolution
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us.
# Gatsby Template for GitLab Pages # VisualGLM-6B
## 模型结构
VisualGLM 模型架构是 ViT + QFormer + ChatGLM,在预训练阶段对 QFormer 和 ViT LoRA 进行训练,在微调阶段对 QFormer 和 ChatGLM LoRA 进行训练,训练目标是自回归损失(根据图像生成正确的文本)和对比损失(输入 ChatGLM 的视觉特征与对应文本的语义特征对齐)
<div align=center>
<img src="./doc/image.png"/>
</div>
## 算法原理
VisualGLM-6B 是一个开源的,支持图像、中文和英文的多模态对话语言模型,语言模型基于 ChatGLM-6B,具有 62 亿参数;图像部分通过训练 BLIP2-Qformer 构建起视觉模型与语言模型的桥梁,整体模型共78亿参数。
VisualGLM-6B 由 SwissArmyTransformer(简称sat) 库训练,这是一个支持Transformer灵活修改、训练的工具库,支持Lora、P-tuning等参数高效微调方法。本项目提供了符合用户习惯的huggingface接口,也提供了基于sat的接口。
结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4量化级别下最低只需6.3G显存)。
## 环境配置
### Docker(方法一)
在光源可拉取推理的docker镜像,拉取方式如下:
```
docker pull xxx
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /path/workspace/
```
### Dockerfile(方法二)
此处提供dockerfile的使用方法
```
docker build --no-cache -t xxx:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
```
### Anaconda(方法三)
此处提供本地配置、编译的详细步骤,例如:
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动:dtk23.10
python:python3.8
torch:1.13.1
torchvision:0.14.1
deepspeed:0.12.3
```
其它非深度学习库参照requirements.txt安装:
```
pip install -r requirements.txt
```
## 推理
使用pip安装依赖
```
pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt
```
此时默认会安装deepspeed库(支持sat库训练),此库对于模型推理并非必要,同时部分Windows环境安装此库时会遇到问题。 如果想绕过deepspeed安装,我们可以将命令改为
```
pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements_wo_ds.txt
pip install -i https://mirrors.aliyun.com/pypi/simple/ --no-deps "SwissArmyTransformer>=0.4.4"
```
使用Huggingface transformers库调用模型,可以通过如下代码(其中图像路径为本地路径,模型路径为THUDM/visulglm-6b):
```
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
image_path = "your image path"
response, history = model.chat(tokenizer, image_path, "描述这张图片。", history=[])
print(response)
response, history = model.chat(tokenizer, image_path, "这张图片可能是在什么场所拍摄的?", history=history)
print(response)
```
得到返回值
```
Specify both input_ids and inputs_embeds at the same time, will use inputs_embeds
这张照片中,一位女士坐在沙发上使用笔记本电脑和鼠标。她似乎正在浏览网页或工作。她的姿势表明她在放松、享受或专注于她的工作。背景中的瓶子可能暗示着饮料或其他日常用品的存在。椅子和沙发的布置也表明这是一个舒适的环境,适合休息或进行轻松的工作活动。
考虑到照片的背景和场景设置,可以推断出这个场景是一个舒适的环境中拍摄的照片,例如家庭住宅或休闲空间。这位女士坐在一张沙发上,周围有瓶子和其他物品,这表明这个地方可能有一些日常用品或装饰。这种布置可能会鼓励人们放松身心并享受他们的日常活动,比如观看电影、阅读书籍或者与亲朋好友聊天。
```
命令行 Demo
```
python cli_demo_hf.py
```
程序会自动下载sat模型,并在命令行中进行交互式的对话,输入指示并回车即可生成回复,输入 clear 可以清空对话历史,输入 stop 终止程序。
API部署
首先需要安装额外的依赖 pip install fastapi uvicorn,然后运行仓库中的 api.py:
```
python api.py
```
程序会自动下载 sat 模型,默认部署在本地的 8080 端口,通过 POST 方法进行调用。下面是用curl请求的例子,一般而言可以也可以使用代码方法进行POST。
```
echo "{\"image\":\"$(base64 path/to/example.jpg)\",\"text\":\"描述这张图片\",\"history\":[]}" > temp.json
curl -X POST -H "Content-Type: application/json" -d @temp.json http://127.0.0.1:8080
```
得到的返回值为
```
{"result":"这张照片中,一个年轻女子坐在沙发上,手里拿着笔记本电脑。她可能正在工作或学习,或者只是放松和享受时间。","history":[["描述这张图片","这张照片中,一个年轻女子坐在沙发上,手里拿着笔记本电脑。她可能正在工作或学习,或者只是放松和享受时间。"]],"status":200,"time":"2024-01-22 11:16:35"}
```
### 精度
## 应用场景
### 算法类别
多模态对话
### 热点应用行业
## 源码仓库及问题反馈
- 此处填本项目gitlab地址
## 参考资料
- https://github.com/THUDM/VisualGLM-6B?tab=readme-ov-file
Examples [Gatsby](https://www.gatsbyjs.org/) website using [GitLab pages](https://about.gitlab.com/product/pages/).
## 🚀 Quick start
1. **Get the code**
For, clone, or download this project. You can also start with a [new GitLab project from template](https://gitlab.com/projects/new#create-from-template) and choose [Gatsby](https://gitlab.com/gitlab-org/project-templates/gatsby).
2. **Start developing.**
Navigate into your new site’s directory and start it up.
```shell
cd my-project/
gatsby develop
```
3. **Open the source code and start editing!**
Your site is now running at `http://localhost:8000`!
_Note: You'll also see a second link: _`http://localhost:8000/___graphql`_. This is a tool you can use to experiment with querying your data. Learn more about using this tool in the [Gatsby tutorial](https://www.gatsbyjs.org/tutorial/part-five/#introducing-graphiql)._
Open your project directory directory in your code editor of choice and edit `src/pages/index.js`. Save your changes and the browser will update in real time!
4. **Deploy with GitLab CI and GitLab Pages**
This project's static Pages are built by GitLab CI, following the steps defined in [`.gitlab-ci.yml`](./gitlab-ci.yml) then deployed to [GitLab pages](https://docs.gitlab.com/ee/user/project/pages/).
## 🧐 What's inside?
A quick look at the top-level files and directories you'll see in a Gatsby project.
.
├── node_modules
├── src
├── .gitignore
├── .prettierrc
├── gatsby-browser.js
├── gatsby-config.js
├── gatsby-node.js
├── gatsby-ssr.js
├── LICENSE
├── package-lock.json
├── package.json
└── README.md
1. **`/node_modules`**: This directory contains all of the modules of code that your project depends on (npm packages) are automatically installed.
2. **`/src`**: This directory will contain all of the code related to what you will see on the front-end of your site (what you see in the browser) such as your site header or a page template. `src` is a convention for “source code”.
3. **`.gitignore`**: This file tells git which files it should not track / not maintain a version history for.
4. **`.prettierrc`**: This is a configuration file for [Prettier](https://prettier.io/). Prettier is a tool to help keep the formatting of your code consistent.
5. **`gatsby-browser.js`**: This file is where Gatsby expects to find any usage of the [Gatsby browser APIs](https://www.gatsbyjs.org/docs/browser-apis/) (if any). These allow customization/extension of default Gatsby settings affecting the browser.
6. **`gatsby-config.js`**: This is the main configuration file for a Gatsby site. This is where you can specify information about your site (metadata) like the site title and description, which Gatsby plugins you’d like to include, etc. (Check out the [config docs](https://www.gatsbyjs.org/docs/gatsby-config/) for more detail).
7. **`gatsby-node.js`**: This file is where Gatsby expects to find any usage of the [Gatsby Node APIs](https://www.gatsbyjs.org/docs/node-apis/) (if any). These allow customization/extension of default Gatsby settings affecting pieces of the site build process.
8. **`gatsby-ssr.js`**: This file is where Gatsby expects to find any usage of the [Gatsby server-side rendering APIs](https://www.gatsbyjs.org/docs/ssr-apis/) (if any). These allow customization of default Gatsby settings affecting server-side rendering.
9. **`LICENSE`**: Gatsby is licensed under the MIT license.
10. **`package-lock.json`** (See `package.json` below, first). This is an automatically generated file based on the exact versions of your npm dependencies that were installed for your project. **(You won’t change this file directly).**
11. **`package.json`**: A manifest file for Node.js projects, which includes things like metadata (the project’s name, author, etc). This manifest is how npm knows which packages to install for your project.
12. **`README.md`**: A text file containing useful reference information about your project.
## 🎓 Learning Gatsby
Looking for more guidance? Full documentation for Gatsby lives [on the website](https://www.gatsbyjs.org/). Here are some places to start:
- **For most developers, we recommend starting with our [in-depth tutorial for creating a site with Gatsby](https://www.gatsbyjs.org/tutorial/).** It starts with zero assumptions about your level of ability and walks through every step of the process.
- **To dive straight into code samples, head [to our documentation](https://www.gatsbyjs.org/docs/).** In particular, check out the _Guides_, _API Reference_, and _Advanced Tutorials_ sections in the sidebar.
# VisualGLM-6B
<p align="center">
🤗 <a href="https://huggingface.co/THUDM/visualglm-6b" target="_blank">HF Repo</a> • ⚒️ <a href="https://github.com/THUDM/SwissArmyTransformer" target="_blank">SwissArmyTransformer (sat)</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
</p>
<p align="center">
• 📃 <a href="https://arxiv.org/abs/2105.13290" target="_blank">[CogView@NeurIPS 21]</a> <a href="https://github.com/THUDM/CogView" target="_blank">[GitHub]</a> • 📃 <a href="https://arxiv.org/abs/2103.10360" target="_blank">[GLM@ACL 22]</a> <a href="https://github.com/THUDM/GLM" target="_blank">[GitHub]</a> <br>
</p>
<p align="center">
👋 Join us on <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1th2q5u69-7tURzFuOPanmuHy9hsZnKA" target="_blank">Slack</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
</p>
<!-- <p align="center">
🤖<a href="https://huggingface.co/spaces/THUDM/visualglm-6b" target="_blank">VisualGLM-6B Online Demo Website</a>
</p> -->
## Introduction
VisualGLM-6B is an open-source, multi-modal dialog language model that supports **images, Chinese, and English**. The language model is based on [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) with 6.2 billion parameters; the image part builds a bridge between the visual model and the language model through the training of [BLIP2-Qformer](https://arxiv.org/abs/2301.12597), with the total model comprising 7.8 billion parameters.
VisualGLM-6B relies on 30M high-quality Chinese image-text pairs from the [CogView](https://arxiv.org/abs/2105.13290) dataset and 300M filtered English image-text pairs for pre-training, with equal weight for Chinese and English. This training method aligns visual information well to the semantic space of ChatGLM. In the subsequent fine-tuning phase, the model is trained on long visual question answering data to generate answers that align with human preferences.
VisualGLM-6B is trained using the [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer) (abbreviated as sat) library, a utility library for flexible modification and training of Transformer, supporting efficient fine-tuning methods like Lora and P-tuning. This project provides a user-friendly huggingface interface, as well as an interface based on sat.
However, as VisualGLM-6B is still at the v1 stage, it is known to have quite a few [**limitations**](#Limitations), such as factual inaccuracy/model hallucination in image description, lack of capturing image detail information, and some limitations from the language model. Please be aware of these issues and evaluate the potential risks before using. In future versions of VisualGLM, we will strive to optimize these issues.
With model quantization technology, users can deploy locally on consumer-grade graphics cards (requiring as little as 6.3G memory under INT4 quantization level).
## Examples
VisualGLM-6B can answer questions related to image description.
![Titanic example](examples/chat_example1.png)
<details>
<summary>It can also combine common sense or propose interesting views. Click to expand/collapse more examples</summary>
![Ironing shirt taxi example](examples/chat_example2.png)
![Mona Lisa dog example](examples/chat_example3.png)
</details>
## Usage
### Model Inference
Install dependencies with pip
```
pip install -i https://pypi.org/simple -r requirements.txt
pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt
```
This will default to installing the deepspeed library (which supports the sat library training). This library is not necessary for model inference and can cause problems when installed in some Windows environments.
If you want to bypass deepspeed installation, you can change the command to:
```
pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements_wo_ds.txt
pip install -i https://mirrors.aliyun.com/pypi/simple/ --no-deps "SwissArmyTransformer>=0.3.6"
```
If you are calling the model using the Huggingface transformers library (you also need to install the above dependency packages!), you can use the following code (where the image path is the local path):
```python
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
image_path = "your image path"
response, history = model.chat(tokenizer, image_path, "描述这张图片。", history=[])
print(response)
response, history = model.chat(tokenizer, image_path, "这张图片可能是在什么场所拍摄的?", history=history)
print(response)
```
If you use the SwissArmyTransformer library to call the model, the method is similar, and you can use the environment variable SAT_HOME to determine the model download location. In the directory of this repository:
```python
import argparse
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
from model import chat, VisualGLMModel
model, model_args = VisualGLMModel.from_pretrained('visualglm-6b', args=argparse.Namespace(fp16=True, skip_init=True))
from sat.model.mixins import CachedAutoregressiveMixin
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
image_path = "your image path or URL"
response, history, cache_image = chat(image_path, model, tokenizer, "Describe this picture.", history=[])
print(response)
response, history, cache_image = chat(None, model, tokenizer, "Where could this picture possibly have been taken?", history=history, image=cache_image)
print(response)
```
Using the `sat` library can also easily carry out efficient parameter fine-tuning. <!-- TODO specific code -->
Please note that the Huggingface model implementation is located in the [Huggingface repository](https://huggingface.co/THUDM/visualglm-6b), and the `sat` model implementation is included in this repository.
## Model Fine-tuning
Multimodal tasks are wide-ranging and diverse, and pre-training often cannot cover all bases.
Here we provide an example of small sample fine-tuning, using 20 labeled images to enhance the model's ability to answer "background" questions.
After unzipping fewshot-data.zip, run the following command:
```
bash finetune/finetune_visualglm.sh
```
Currently we support three types of (parameter-efficient) fine-tuning:
* LoRA: In the given example, we add rank=10 LoRA for layer 0 and layer 14 in ChatGLM. You can adjust `--layer_range` and `--lora_rank` to fit your application and data amount.
* QLoRA: If your resource is limited, consider using `bash finetune/finetune_visualglm_qlora.sh`, which do 4-bit quantization for ChatGLM Linear layers, reducing the required GPU memory to 9.8 GB.
* P-tuning: You can replace `--use_lora` to `--use_ptuning`, but not recommended, unless your application has a relatively fixed input and output template.
After training, you can use the following command for inference:
```
python cli_demo.py --from_pretrained your_checkpoint_path --prompt_zh 这张图片的背景里有什么内容?
```
Fine-tuning requires the installation of the deepspeed library, and currently this process only supports the Linux system. More examples and instructions for the Windows system will be completed in the near future.
If you want to merge LoRA weights into original weights, just call `merge_lora()`:
```python
from finetune_visualglm import FineTuneVisualGLMModel
import argparse
model, args = FineTuneVisualGLMModel.from_pretrained('checkpoints/finetune-visualglm-6b-05-19-07-36',
args=argparse.Namespace(
fp16=True,
skip_init=True,
use_gpu_initialization=True,
device='cuda',
))
model.get_mixin('lora').merge_lora()
args.layer_range = []
args.save = 'merge_lora'
args.mode = 'inference'
from sat.training.model_io import save_checkpoint
save_checkpoint(1, model, None, None, args)
```
## Deployment Tools
### Command Line Demo
```shell
python cli_demo.py
```
The program will automatically download the sat model and interact in the command line. You can generate replies by entering instructions and pressing enter. Enter 'clear' to clear the conversation history and 'stop' to stop the program.
![cli_demo](examples/thu.png)
The program provides the following hyperparameters to control the generation process and quantization accuracy:
```
usage: cli_demo.py [-h] [--max_length MAX_LENGTH] [--top_p TOP_P] [--top_k TOP_K] [--temperature TEMPERATURE] [--english] [--quant {8,4}]
optional arguments:
-h, --help show this help message and exit
--max_length MAX_LENGTH
max length of the total sequence
--top_p TOP_P top p for nucleus sampling
--top_k TOP_K top k for top k sampling
--temperature TEMPERATURE
temperature for sampling
--english only output English
--quant {8,4} quantization bits
```
Note that during training, the prompt words for English Q&A pairs are 'Q: A:', while in Chinese they are '问:答:'. The web demo uses Chinese prompts, so the English replies will be worse and interspersed with Chinese; if you need English replies, please use the --english option in cli_demo.py.
We also provide a typewriter effect command line tool inherited from ChatGLM-6B, which uses the Huggingface model:
```shell
python cli_demo_hf.py
```
### Web Demo
![web_demo](examples/web_demo.png)
We provide a web demo based on [Gradio](https://gradio.app). First, install Gradio: `pip install gradio`.
Then download and enter this repository and run `web_demo.py`:
```
git clone https://github.com/THUDM/VisualGLM-6B
cd VisualGLM-6B
python web_demo.py
```
The program will automatically download the sat model and run a Web Server, outputting the address. Open the output address in your browser to use it.
We also provide a web tool with a typewriter effect inherited from ChatGLM-6B, which uses the Huggingface model and will run on port :8080 after starting:
```shell
python web_demo_hf.py
```
Both web demos accept the command line parameter --share to generate a public link for gradio, and accept --quant 4 and --quant 8 to use 4-bit quantization/8-bit quantization to reduce GPU memory usage.
### API Deployment
First, you need to install additional dependencies pip install fastapi uvicorn, then run the api.py in the repository:
```shell
python api.py
```
The program will automatically download the sat model, and by default it will be deployed on local port 8080 and called through the POST method. Below is an example of a request with curl, but in general you can also use a code method to POST.
```shell
echo "{\"image\":\"$(base64 path/to/example.jpg)\",\"text\":\"Describe this picture\",\"history\":[]}" > temp.json
curl -X POST -H "Content-Type: application/json" -d @temp.json http://127.0.0.1:8080
```
We also provide an api_hf.py that uses the Huggingface model, which works the same way as the sat model's api:
```shell
python api_hf.py
```
## Model Quantization
In the Huggingface implementation, the model is loaded with FP16 precision by default, and running the above code requires about 15GB of GPU memory. If your GPU memory is limited, you can try loading the model in a quantized manner.
Here's how:
```python
# Modify as needed, currently only 4/8 bit quantization is supported. The following will only quantize ChatGLM, as the error is larger when quantizing ViT
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).quantize(8).half().cuda()
```
In the sat implementation, you need to change the loading location to 'cpu' first, and then perform quantization. Here's how, see cli_demo.py for details:
```python
from sat.quantization.kernels import quantize
model = quantize(model, args.quant).cuda()
# only need 7GB GPU memory to inference
```
## Limitations
This project is currently at V1 version of the visual and language model parameters, the amount of calculation is small, we have summarized the following main improvements:
- Image description factuality/model hallucination problem. When generating long descriptions of images, as the distance from the image increases, the language model will dominate, and there is a certain possibility of generating content that does not exist in the image based on the context.
- Attribute mismatch problem. In scenes with multiple objects, some attributes of some objects are often incorrectly inserted onto other objects.
- Resolution issue. This project uses a resolution of 224*224, which is the most commonly used size in visual models; however, for more fine-grained understanding, larger resolution and computation are necessary.
- Due to data and other reasons, the model currently does not have the ability to perform Chinese OCR (some ability for English OCR), we will add this ability in future versions.
## License
The code in this repository is open source under the Apache-2.0 license, while the use of the VisualGLM-6B model weights must comply with the Model License.
## Citation & Acknowledgements
If you find our work helpful, please consider citing the following papers
```
@inproceedings{du2022glm,
title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
pages={320--335},
year={2022}
}
@article{ding2021cogview,
title={Cogview: Mastering text-to-image generation via transformers},
author={Ding, Ming and Yang, Zhuoyi and Hong, Wenyi and Zheng, Wendi and Zhou, Chang and Yin, Da and Lin, Junyang and Zou, Xu and Shao, Zhou and Yang, Hongxia and others},
journal={Advances in Neural Information Processing Systems},
volume={34},
pages={19822--19835},
year={2021}
}
```
In the instruction fine-tuning phase of the VisualGLM-6B dataset, there are some English image-text data from the [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) and [LLAVA](https://github.com/haotian-liu/LLaVA) projects, as well as many classic cross-modal work datasets. We sincerely thank them for their contributions.
import os
import json
import uvicorn
from fastapi import FastAPI, Request
from model import is_chinese, get_infer_setting, generate_input, chat
import datetime
import torch
gpu_number = 0
model, tokenizer = get_infer_setting(gpu_device=gpu_number)
app = FastAPI()
@app.post('/')
async def visual_glm(request: Request):
json_post_raw = await request.json()
print("Start to process request")
json_post = json.dumps(json_post_raw)
request_data = json.loads(json_post)
input_text, input_image_encoded, history = request_data['text'], request_data['image'], request_data['history']
input_para = {
"max_length": 2048,
"min_length": 50,
"temperature": 0.8,
"top_p": 0.4,
"top_k": 100,
"repetition_penalty": 1.2
}
input_para.update(request_data)
is_zh = is_chinese(input_text)
input_data = generate_input(input_text, input_image_encoded, history, input_para)
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
with torch.no_grad():
answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
response = {
"result": answer,
"history": history,
"status": 200,
"time": time
}
return response
if __name__ == '__main__':
uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
import os
import json
from transformers import AutoTokenizer, AutoModel
import uvicorn
from fastapi import FastAPI, Request
import datetime
from model import process_image
import torch
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
app = FastAPI()
@app.post('/')
async def visual_glm(request: Request):
json_post_raw = await request.json()
print("Start to process request")
json_post = json.dumps(json_post_raw)
request_data = json.loads(json_post)
history = request_data.get("history")
image_encoded = request_data.get("image")
query = request_data.get("text")
image_path = process_image(image_encoded)
with torch.no_grad():
result = model.stream_chat(tokenizer, image_path, query, history=history)
last_result = None
for value in result:
last_result = value
answer = last_result[0]
if os.path.isfile(image_path):
os.remove(image_path)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
response = {
"result": answer,
"history": history,
"status": 200,
"time": time
}
return response
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
\ No newline at end of file
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import argparse
from transformers import AutoTokenizer
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from model import VisualGLMModel, chat
from finetune_visualglm import FineTuneVisualGLMModel
from sat.model import AutoModel
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
parser.add_argument("--english", action='store_true', help='only output English')
parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
parser.add_argument("--from_pretrained", type=str, default="/data", help='pretrained ckpt')
parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
args = parser.parse_args()
# load model
model, model_args = AutoModel.from_pretrained(
args.from_pretrained,
args=argparse.Namespace(
fp16=True,
skip_init=True,
use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
))
model = model.eval()
if args.quant:
quantize(model, args.quant)
if torch.cuda.is_available():
model = model.cuda()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
tokenizer = AutoTokenizer.from_pretrained("/data", trust_remote_code=True)
if not args.english:
print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
else:
print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
with torch.no_grad():
while True:
history = None
cache_image = None
if not args.english:
image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")
else:
image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ")
if image_path == 'stop':
break
if len(image_path) > 0:
query = args.prompt_en if args.english else args.prompt_zh
else:
if not args.english:
query = input("用户:")
else:
query = input("User: ")
while True:
if query == "clear":
break
if query == "stop":
sys.exit(0)
try:
response, history, cache_image = chat(
image_path,
model,
tokenizer,
query,
history=history,
image=cache_image,
max_length=args.max_length,
top_p=args.top_p,
temperature=args.temperature,
top_k=args.top_k,
english=args.english,
invalid_slices=[slice(63823, 130000)] if args.english else []
)
except Exception as e:
print(e)
break
sep = 'A:' if args.english else '答:'
print("VisualGLM-6B:"+response.split(sep)[-1].strip())
image_path = None
if not args.english:
query = input("用户:")
else:
query = input("User: ")
if __name__ == "__main__":
main()
import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
import torch
tokenizer = AutoTokenizer.from_pretrained("/data", trust_remote_code=True)
model = AutoModel.from_pretrained("/data", trust_remote_code=True).half().cuda()
model = model.eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
def build_prompt(history, prefix):
prompt = prefix
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nVisualGLM-6B:{response}"
return prompt
def signal_handler(signal, frame):
global stop_stream
stop_stream = True
def main():
global stop_stream
while True:
history = []
prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序"
print(prefix)
image_path = input("\n请输入图片路径:")
if image_path == "stop":
break
prefix = prefix + "\n" + image_path
query = "描述这张图片。"
while True:
count = 0
with torch.no_grad():
for response, history in model.stream_chat(tokenizer, image_path, query, history=history):
if stop_stream:
stop_stream = False
break
else:
count += 1
if count % 8 == 0:
os.system(clear_command)
print(build_prompt(history, prefix), flush=True)
signal.signal(signal.SIGINT, signal_handler)
os.system(clear_command)
print(build_prompt(history, prefix), flush=True)
query = input("\n用户:")
if query.strip() == "clear":
break
if query.strip() == "stop":
stop_stream = True
exit(0)
# if query.strip() == "clear":
# history = []
# os.system(clear_command)
# print(prefix)
# continue
if __name__ == "__main__":
main()
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import argparse
from transformers import AutoTokenizer
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from model import VisualGLMModel, chat
from finetune_visualglm import FineTuneVisualGLMModel
from sat.model import AutoModel
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
parser.add_argument("--english", action='store_true', help='only output English')
parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
parser.add_argument("--from_pretrained", type=str, default="/data", help='pretrained ckpt')
parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
args = parser.parse_args()
# load model
model, model_args = AutoModel.from_pretrained(
args.from_pretrained,
args=argparse.Namespace(
fp16=True,
skip_init=True,
use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
), overwrite_args={'model_parallel_size': 2})
model = model.eval()
if args.quant:
quantize(model.transformer, args.quant)
if torch.cuda.is_available():
model = model.cuda()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
tokenizer = AutoTokenizer.from_pretrained("/data", trust_remote_code=True)
image_path = 'fewshot-data/meme.png'
query = args.prompt_en if args.english else args.prompt_zh
history = None
cache_image = None
response, history, cache_image = chat(
image_path,
model,
tokenizer,
query,
history=history,
image=cache_image,
max_length=args.max_length,
top_p=args.top_p,
temperature=args.temperature,
top_k=args.top_k,
english=args.english,
invalid_slices=[slice(63823, 130000)] if args.english else []
)
sep = 'A:' if args.english else '答:'
print(response.split(sep)[-1].strip())
if __name__ == "__main__":
main()
import os
import torch
import argparse
from sat import mpu, get_args, get_tokenizer
from sat.training.deepspeed_training import training_main
from model import VisualGLMModel
from sat.model.finetune import PTuningV2Mixin
from sat.model.finetune.lora2 import LoraMixin
class FineTuneVisualGLMModel(VisualGLMModel):
def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
if args.use_ptuning:
self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
if args.use_lora:
self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
# self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
elif args.use_qlora:
self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
self.args = args
@classmethod
def add_model_specific_args(cls, parser):
group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
group.add_argument('--pre_seq_len', type=int, default=8)
group.add_argument('--lora_rank', type=int, default=10)
group.add_argument('--use_ptuning', action="store_true")
group.add_argument('--use_lora', action="store_true")
group.add_argument('--use_qlora', action="store_true")
group.add_argument('--layer_range', nargs='+', type=int, default=None)
return super().add_model_specific_args(parser)
def disable_untrainable_params(self):
enable = []
if self.args.use_ptuning:
enable.extend(['ptuning'])
if self.args.use_lora or self.args.use_qlora:
enable.extend(['matrix_A', 'matrix_B'])
for n, p in self.named_parameters():
flag = False
for e in enable:
if e.lower() in n.lower():
flag = True
break
if not flag:
p.requires_grad_(False)
else:
print(n)
def get_batch(data_iterator, args, timers):
# Items and their type.
keys = ['input_ids', 'labels']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
data_i = mpu.broadcast_data(['image'], data, torch.float32)
# Unpack.
tokens = data_b['input_ids'].long()
labels = data_b['labels'].long()
img = data_i['image']
if args.fp16:
img = img.half()
return tokens, labels, img, data['pre_image']
from torch.nn import CrossEntropyLoss
def forward_step(data_iterator, model, args, timers):
"""Forward step."""
# Get the batch.
timers('batch generator').start()
tokens, labels, image, pre_image = get_batch(
data_iterator, args, timers)
timers('batch generator').stop()
logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0]
dtype = logits.dtype
lm_logits = logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(dtype)
loss = loss.to(dtype)
return loss, {'loss': loss}
from model.blip2 import BlipImageEvalProcessor
from torch.utils.data import Dataset
import json
from PIL import Image
class FewShotDataset(Dataset):
def __init__(self, path, processor, tokenizer, args):
max_seq_length = args.max_source_length + args.max_target_length
with open(path, 'r', encoding='utf-8') as f:
data = json.load(f)
self.images = []
self.input_ids = []
self.labels = []
for item in data:
image = processor(Image.open(item['img']).convert('RGB'))
input0 = tokenizer.encode("<img>", add_special_tokens=False)
input1 = [tokenizer.pad_token_id] * args.image_length
input2 = tokenizer.encode("</img>问:"+item['prompt']+"\n答:", add_special_tokens=False)
a_ids = sum([input0, input1, input2], [])
b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False)
if len(a_ids) > args.max_source_length - 1:
a_ids = a_ids[: args.max_source_length - 1]
if len(b_ids) > args.max_target_length - 2:
b_ids = b_ids[: args.max_target_length - 2]
pre_image = len(input0)
input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
context_length = input_ids.index(tokenizer.bos_token_id)
mask_position = context_length - 1
labels = [-100] * context_length + input_ids[mask_position+1:]
pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
labels = labels + [tokenizer.pad_token_id] * pad_len
if args.ignore_pad_token_for_loss:
labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
self.images.append(image)
self.input_ids.append(input_ids)
self.labels.append(labels)
self.pre_image = pre_image
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return {
"image": self.images[idx],
"input_ids": self.input_ids[idx],
"labels": self.labels[idx],
"pre_image": self.pre_image
}
def create_dataset_function(path, args):
tokenizer = get_tokenizer(args)
image_processor = BlipImageEvalProcessor(224)
dataset = FewShotDataset(path, image_processor, tokenizer, args)
return dataset
if __name__ == '__main__':
py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--max_source_length', type=int)
py_parser.add_argument('--max_target_length', type=int)
py_parser.add_argument('--ignore_pad_token_for_loss', type=bool, default=True)
# py_parser.add_argument('--old_checkpoint', action="store_true")
py_parser.add_argument('--source_prefix', type=str, default="")
py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
args.device = 'cpu'
model_type = 'visualglm-6b'
model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
if torch.cuda.is_available():
model = model.to('cuda')
tokenizer = get_tokenizer(args)
label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
def data_collator(examples):
for example in examples:
example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
example['labels'] = torch.tensor(example['labels'], dtype=torch.long)
ret = {
'input_ids': torch.stack([example['input_ids'] for example in examples]),
'labels': torch.stack([example['labels'] for example in examples]),
'image': torch.stack([example['image'] for example in examples]),
'pre_image': example['pre_image']
}
return ret
training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator)
\ No newline at end of file
/**
* Implement Gatsby's Browser APIs in this file.
*
* See: https://www.gatsbyjs.org/docs/browser-apis/
*/
// You can delete this file if you're not using it
module.exports = {
pathPrefix: `/gatsby/`,
siteMetadata: {
title: `Gatsby Default Starter`,
description: `Kick off your next, great Gatsby project with this default starter. This barebones starter ships with the main Gatsby configuration files you might need.`,
author: `@gatsbyjs`,
},
plugins: [
`gatsby-plugin-react-helmet`,
{
resolve: `gatsby-source-filesystem`,
options: {
name: `images`,
path: `${__dirname}/src/images`,
},
},
`gatsby-transformer-sharp`,
`gatsby-plugin-sharp`,
{
resolve: `gatsby-plugin-manifest`,
options: {
name: `gatsby-starter-default`,
short_name: `starter`,
start_url: `/`,
background_color: `#663399`,
theme_color: `#663399`,
display: `minimal-ui`,
icon: `src/images/gatsby-icon.png`, // This path is relative to the root of the site.
},
},
// this (optional) plugin enables Progressive Web App + Offline functionality
// To learn more, visit: https://gatsby.dev/offline
// `gatsby-plugin-offline`,
],
}
/**
* Implement Gatsby's Node APIs in this file.
*
* See: https://www.gatsbyjs.org/docs/node-apis/
*/
// You can delete this file if you're not using it
/**
* Implement Gatsby's SSR (Server Side Rendering) APIs in this file.
*
* See: https://www.gatsbyjs.org/docs/ssr-apis/
*/
// You can delete this file if you're not using it
This diff is collapsed.
{
"name": "gatsby-starter-default",
"private": true,
"description": "A simple starter to get up and developing quickly with Gatsby",
"version": "0.1.0",
"author": "Kyle Mathews <mathews.kyle@gmail.com>",
"dependencies": {
"gatsby": "^2.18.12",
"gatsby-image": "^2.2.34",
"gatsby-plugin-manifest": "^2.2.31",
"gatsby-plugin-offline": "^3.0.27",
"gatsby-plugin-react-helmet": "^3.1.16",
"gatsby-plugin-sharp": "^2.3.5",
"gatsby-source-filesystem": "^2.1.40",
"gatsby-transformer-sharp": "^2.3.7",
"prop-types": "^15.7.2",
"react": "^16.12.0",
"react-dom": "^16.12.0",
"react-helmet": "^5.2.1"
},
"devDependencies": {
"prettier": "^1.19.1"
},
"keywords": [
"gatsby"
],
"license": "MIT",
"scripts": {
"build": "gatsby build --prefix-paths",
"develop": "gatsby develop",
"format": "prettier --write \"**/*.{js,jsx,json,md}\"",
"start": "npm run develop",
"serve": "gatsby serve",
"clean": "gatsby clean",
"test": "echo \"Write tests! -> https://gatsby.dev/unit-testing\""
},
"repository": {
"type": "git",
"url": "https://github.com/gatsbyjs/gatsby-starter-default"
},
"bugs": {
"url": "https://github.com/gatsbyjs/gatsby/issues"
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment