"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "74837171ab0a4d23ec26a4f0d54df8f82c700247"
Commit c04f261a authored by dongchy920's avatar dongchy920
Browse files

InstruceBLIP

parents
Pipeline #1594 canceled with stages
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# project-specific
output/
debug*/
*.bak
*.dir
*.dat
*.tsv
*.gz
cache/
# datasets
lavis/coco
lavis/vg
lavis/scienceqa
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
- id: trailing-whitespace
- id: check-ast
- id: no-commit-to-branch
args: ['--branch=main']
- id: check-added-large-files
args: ['--maxkb=5000']
- id: end-of-file-fixer
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
args: [
# only error for syntax errors and undefined names
"--select=E9,F63,F7,F82",
]
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source
# Salesforce Open Source Community Code of Conduct
## About the Code of Conduct
Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.
Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
race, age, religion, level of experience, education, socioeconomic status, or
other similar personal characteristics.
The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.
All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender
identity and expression, sexual orientation, disability, physical appearance,
body size, ethnicity, nationality, race, age, religion, level of experience, education,
socioeconomic status, or other similar personal characteristics.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
* Advocating for or encouraging any of the above behaviors
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
and will result in a response that is deemed necessary and appropriate to the
circumstances. The committee is obligated to maintain confidentiality with
regard to the reporter of an incident. Further details of specific enforcement
policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct
Committee.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2023] Lightning AI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
BSD 3-Clause License
Copyright (c) 2022 Salesforce, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
recursive-include lavis/configs *.yaml *.json
recursive-include lavis/projects *.yaml *.json
recursive-exclude lavis/datasets/download_scripts *
recursive-exclude lavis/output *
include requirements.txt
include lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz
# InstructBLIP
## Paper
- https://arxiv.org/pdf/2305.06500
## Model Architecture
InstructBLIP is the latest model in the BLIP family. It brings instruction tuning, a technique from the LLM field, to vision-language multimodal learning and achieves state-of-the-art results on a wide range of vision and language tasks. The BLIP family consists of the original BLIP model, BLIP-2, and the InstructBLIP model used in this project.
<div align=center>
<img src="./assets/InstructBLIP.png"/>
</div>
## Algorithm
BLIP (Bootstrapping Language-Image Pre-training) is a multimodal framework proposed by Salesforce in 2022. It unifies understanding and generation by introducing cross-modal encoders and decoders so that information can flow across modalities. Whereas CLIP projects its inputs into a latent vector space using two independent modules, an image encoder and a text encoder, BLIP's key innovation is the MED (Multimodal mixture of Encoder-Decoder) architecture, which supports effective multi-task pre-training and transfer learning. MED comprises two unimodal encoders (an Image Encoder and a Text Encoder), an image-grounded text encoder, and an image-grounded text decoder, jointly optimized with three types of losses:
<div align=center>
<img src="./assets/BLIP.png"/>
</div>
Like ALBEF and BLIP, BLIP-2 aims to train an image-text multimodal pre-trained model. The difference, and its biggest innovation, is that BLIP-2 trains the multimodal model on top of a frozen image encoder and a frozen language model, which gives it a large efficiency advantage. To let images and text interact, BLIP-2 introduces the Q-Former module to align the two modalities:
<div align=center>
<img src="./assets/BLIP2.png"/>
</div>
The architecture of InstructBLIP is similar to that of BLIP-2: it is initialized from a pre-trained BLIP-2 model and consists of an image encoder, an LLM, and a Q-Former. During instruction tuning, only the Q-Former is trained, while the parameters of the image encoder and the LLM remain frozen.
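For reference, the sketch below shows how such a model can be instantiated through the LAVIS registry and how to confirm that only the Q-Former (and projection layers) remain trainable. It is a minimal illustration assuming the upstream `blip2_vicuna_instruct` / `vicuna7b` registry names, not this project's training entry point (see the Training section below).
```
# Minimal sketch: assumes the upstream LAVIS registry names; not this project's training script.
import torch
from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# InstructBLIP = frozen ViT image encoder + Q-Former + frozen Vicuna-7B LLM.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_vicuna_instruct", model_type="vicuna7b", is_eval=True, device=device
)

# Only the Q-Former and its projection layers should show up as trainable.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable parameter tensors, e.g. {trainable[:3]}")
```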
## Environment Setup
### Docker (Option 1)
Pull the Docker image from [光源](https://www.sourcefind.cn/#/service-list):
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.8
```
Create a container and mount your working directory for development:
```
docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
# Change 1: replace {name} with a custom container name; the suggested convention is {framework_dtkVersion_userName}, optionally prefixed with the intended purpose
# Change 2: replace {docker_image} with the image used to create the container, e.g. image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.8
# Change 3: -v mounts a host path to the specified path inside the container
pip install -r requirements.txt
```
### Dockerfile (Option 2)
```
cd docker
docker build --no-cache -t instructblip_pytorch:1.0 .
docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
Using conda is the recommended way to set up the environment on online nodes.
Create and activate a conda environment with Python 3.8:
```
conda create -n InstructBLIP python=3.8
conda activate InstructBLIP
```
The specialized deep learning libraries required by this project for DCU GPUs can be downloaded and installed from the [光合](https://developer.hpccube.com/tool/) developer community.
```
DTK driver: dtk24.04.1
python: python3.8
pytorch: 2.1.0
torchvision: 0.16.0
```
Install the remaining dependencies:
```
pip install -r requirements.txt
# The project automatically downloads some models from Hugging Face, so a Hugging Face mirror is needed
pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
```
## Dataset
Download the [ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA/tree/main) dataset.
Preprocess ScienceQA:
```
# Edit the preprocessing script so it points to the directory where the ScienceQA dataset was extracted
python scienceqa_data_preprocess.py
```
This command converts ScienceQA into an instruction-tuning dataset with the following instruction format:
```
<Image> Context: { {hint} {lecture} } Question: { {question} } Options: { {choices} } Answer: (a) { {answer} }
```
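The actual conversion is implemented in scienceqa_data_preprocess.py; the snippet below is only a rough illustration of the template above, with field names following the ScienceQA record layout and a made-up example record.
```
# Illustrative only: approximates the instruction template; the real logic lives in scienceqa_data_preprocess.py.
import string

def to_instruction(record):
    # Context combines the optional hint and lecture fields.
    context = " ".join(s for s in (record.get("hint", ""), record.get("lecture", "")) if s)
    # Options are lettered (a), (b), ... in the order given by the record.
    options = " ".join(
        f"({letter}) {choice}"
        for letter, choice in zip(string.ascii_lowercase, record["choices"])
    )
    letter = string.ascii_lowercase[record["answer"]]
    return (
        f"<Image> Context: {context} Question: {record['question']} "
        f"Options: {options} Answer: ({letter}) {record['choices'][record['answer']]}"
    )

example = {
    "hint": "The map shows one biome.",
    "lecture": "",
    "question": "Which biome is shown on the map?",
    "choices": ["tundra", "rainforest"],
    "answer": 1,
}
print(to_instruction(example))
```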
## Training
Perform LoRA fine-tuning with the pre-trained weights [vicuna-7b-v1.1](https://huggingface.co/lmsys/vicuna-7b-v1.1) or [vicuna-13b-v1.1](https://huggingface.co/lmsys/vicuna-13b-v1.1).
Update the dataset path in lavis/configs/datasets/scienceqa/default.yaml.
Start training:
```
# scienceqa selects the dataset and 60 is the experiment ID, corresponding to the config file lavis/projects/instructblip/train/scienceqa/finetune_instructblip_scienceqa_60.yaml
bash run_scripts/instructblip/train/run_finetune_instructblip_experiments.sh scienceqa 60
```
## Inference
- Download the LLM weights, vicuna-7b or vicuna-13b:
[vicuna-7b-v1.1](https://huggingface.co/lmsys/vicuna-7b-v1.1)
[vicuna-13b-v1.1](https://huggingface.co/lmsys/vicuna-13b-v1.1)
- Download the InstructBLIP checkpoint, instruct_blip_vicuna7b or instruct_blip_vicuna13b:
[instruct_blip_vicuna7b_trimmed](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth)
[instruct_blip_vicuna13b_trimmed](https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth)
Edit lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml and blip2_instruct_vicuna13b.yaml so that pretrained and llm_model point to the local paths of the downloaded weights:
```
pretrained: "/mnt/LAVIS-main/instruct_blip_vicuna7b_trimmed.pth"
llm_model: "/mnt/LAVIS-main/vicuna-7b-v1.1"
```
Run the inference script. After automatically downloading a few pre-trained models, it prints a local URL on the command line; open the link to access the web UI:
```
python projects/instructblip/run_demo.py
```
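If the Gradio web UI is not needed, the same checkpoint can also be queried directly from Python. The sketch below is illustrative only; it again assumes the upstream LAVIS registry names, and demo.jpg stands in for any local image.
```
# Illustrative only: query InstructBLIP without the web UI (assumes upstream LAVIS registry names).
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_vicuna_instruct", model_type="vicuna7b", is_eval=True, device=device
)

# Replace demo.jpg with any local image.
raw_image = Image.open("demo.jpg").convert("RGB")
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

# Instruction-style prompt; generate() returns a list of generated strings.
output = model.generate({"image": image, "prompt": "Describe this image in detail."})
print(output[0])
```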
## Results
The demo uses Gradio to build a visual interface: upload an image, enter a prompt, and click Submit; the corresponding text description appears in the output panel on the right:
<div align=center>
<img src="./assets/demo.png"/>
</div>
## Accuracy
## Application Scenarios
### Algorithm Category
Multimodal
### Key Application Industries
AIGC, design, education
## Source Repository and Issue Reporting
[https://developer.hpccube.com/codes/modelzoo/InstructBLIP_pytorch](https://developer.hpccube.com/codes/modelzoo/InstructBLIP_pytorch)
## References
[https://github.com/salesforce/LAVIS/tree/main](https://github.com/salesforce/LAVIS/tree/main)
[https://github.com/AttentionX/InstructBLIP_PEFT](https://github.com/AttentionX/InstructBLIP_PEFT)
## Security
Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as possible, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from PIL import Image
import requests
import streamlit as st
import torch
@st.cache()
def load_demo_image():
img_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
)
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache_root = "/export/home/.cache/lavis/"
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from PIL import Image
import requests
import torch
import os
from lavis.common.registry import registry
from lavis.processors import *
from lavis.models import *
from lavis.common.utils import build_default_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_demo_image():
img_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
)
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def read_img(filepath):
raw_image = Image.open(filepath).convert("RGB")
return raw_image
# model
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
feature_extractor = BlipFeatureExtractor(pretrained=model_url)
feature_extractor.eval()
feature_extractor = feature_extractor.to(device)
# preprocessors
vis_processor = BlipImageEvalProcessor(image_size=224)
text_processor = BlipCaptionProcessor()
# files to process
# file_root = "/export/home/.cache/lavis/coco/images/val2014"
file_root = "/export/home/.cache/lavis/coco/images/train2014"
filepaths = os.listdir(file_root)
print(len(filepaths))
caption = "dummy"
path2feat = dict()
bsz = 256
images_in_batch = []
filepaths_in_batch = []
for i, filename in enumerate(filepaths):
if i % bsz == 0 and i > 0:
images_in_batch = torch.cat(images_in_batch, dim=0).to(device)
with torch.no_grad():
image_features = feature_extractor(
images_in_batch, caption, mode="image", normalized=True
)[:, 0]
for filepath, image_feat in zip(filepaths_in_batch, image_features):
path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()
images_in_batch = []
filepaths_in_batch = []
print(len(path2feat), image_features.shape)
else:
filepath = os.path.join(file_root, filename)
image = read_img(filepath)
image = vis_processor(image).unsqueeze(0)
images_in_batch.append(image)
filepaths_in_batch.append(filepath)
torch.save(path2feat, "path2feat_coco_train2014.pth")
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import streamlit as st
from app import device, load_demo_image
from app.utils import load_model_cache
from lavis.processors import load_processor
from PIL import Image
def app():
# ===== layout =====
model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
sampling_method = st.sidebar.selectbox(
"Sampling method:", ["Beam search", "Nucleus sampling"]
)
st.markdown(
"<h1 style='text-align: center;'>Image Description Generation</h1>",
unsafe_allow_html=True,
)
instructions = """Try the provided image or upload your own:"""
file = st.file_uploader(instructions)
use_beam = sampling_method == "Beam search"
col1, col2 = st.columns(2)
if file:
raw_img = Image.open(file).convert("RGB")
else:
raw_img = load_demo_image()
col1.header("Image")
w, h = raw_img.size
scaling_factor = 720 / w
resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
col1.image(resized_image, use_column_width=True)
col2.header("Description")
cap_button = st.button("Generate")
# ==== event ====
vis_processor = load_processor("blip_image_eval").build(image_size=384)
if cap_button:
if model_type.startswith("BLIP"):
blip_type = model_type.split("_")[1].lower()
model = load_model_cache(
"blip_caption",
model_type=f"{blip_type}_coco",
is_eval=True,
device=device,
)
img = vis_processor(raw_img).unsqueeze(0).to(device)
captions = generate_caption(
model=model, image=img, use_nucleus_sampling=not use_beam
)
col2.write("\n\n".join(captions), use_column_width=True)
def generate_caption(
model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
):
samples = {"image": image}
captions = []
if use_nucleus_sampling:
for _ in range(5):
caption = model.generate(
samples,
use_nucleus_sampling=True,
max_length=max_length,
min_length=min_length,
top_p=0.9,
)
captions.append(caption[0])
else:
caption = model.generate(
samples,
use_nucleus_sampling=False,
num_beams=num_beams,
max_length=max_length,
min_length=min_length,
)
captions.append(caption[0])
return captions
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import plotly.graph_objects as go
import requests
import streamlit as st
import torch
from lavis.models import load_model
from lavis.processors import load_processor
from lavis.processors.blip_processors import BlipCaptionProcessor
from PIL import Image
from app import device, load_demo_image
from app.utils import load_blip_itm_model
from lavis.processors.clip_processors import ClipImageEvalProcessor
@st.cache()
def load_demo_image(img_url=None):
if not img_url:
img_url = "https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
@st.cache(
hash_funcs={
torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
.cpu()
.numpy()
},
allow_output_mutation=True,
)
def load_model_cache(model_type, device):
if model_type == "blip":
model = load_model(
"blip_feature_extractor", model_type="base", is_eval=True, device=device
)
elif model_type == "albef":
model = load_model(
"albef_feature_extractor", model_type="base", is_eval=True, device=device
)
elif model_type == "CLIP_ViT-B-32":
model = load_model(
"clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
)
elif model_type == "CLIP_ViT-B-16":
model = load_model(
"clip_feature_extractor", "ViT-B-16", is_eval=True, device=device
)
elif model_type == "CLIP_ViT-L-14":
model = load_model(
"clip_feature_extractor", "ViT-L-14", is_eval=True, device=device
)
return model
def app():
model_type = st.sidebar.selectbox(
"Model:",
["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
)
score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"])
# ===== layout =====
st.markdown(
"<h1 style='text-align: center;'>Zero-shot Classification</h1>",
unsafe_allow_html=True,
)
instructions = """Try the provided image or upload your own:"""
file = st.file_uploader(instructions)
st.header("Image")
if file:
raw_img = Image.open(file).convert("RGB")
else:
raw_img = load_demo_image()
st.image(raw_img) # , use_column_width=True)
col1, col2 = st.columns(2)
col1.header("Categories")
cls_0 = col1.text_input("category 1", value="merlion")
cls_1 = col1.text_input("category 2", value="sky")
cls_2 = col1.text_input("category 3", value="giraffe")
cls_3 = col1.text_input("category 4", value="fountain")
cls_4 = col1.text_input("category 5", value="marina bay")
cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4]
cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]
if len(cls_names) != len(set(cls_names)):
st.error("Please provide unique class names")
return
button = st.button("Submit")
col2.header("Prediction")
# ===== event =====
if button:
if model_type.startswith("BLIP"):
text_processor = BlipCaptionProcessor(prompt="A picture of ")
cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
if score_type == "Cosine":
vis_processor = load_processor("blip_image_eval").build(image_size=224)
img = vis_processor(raw_img).unsqueeze(0).to(device)
feature_extractor = load_model_cache(model_type="blip", device=device)
sample = {"image": img, "text_input": cls_prompt}
with torch.no_grad():
image_features = feature_extractor.extract_features(
sample, mode="image"
).image_embeds_proj[:, 0]
text_features = feature_extractor.extract_features(
sample, mode="text"
).text_embeds_proj[:, 0]
sims = (image_features @ text_features.t())[
0
] / feature_extractor.temp
else:
vis_processor = load_processor("blip_image_eval").build(image_size=384)
img = vis_processor(raw_img).unsqueeze(0).to(device)
model = load_blip_itm_model(device)
output = model(img, cls_prompt, match_head="itm")
sims = output[:, 1]
sims = torch.nn.Softmax(dim=0)(sims)
inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
elif model_type.startswith("ALBEF"):
vis_processor = load_processor("blip_image_eval").build(image_size=224)
img = vis_processor(raw_img).unsqueeze(0).to(device)
text_processor = BlipCaptionProcessor(prompt="A picture of ")
cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
feature_extractor = load_model_cache(model_type="albef", device=device)
sample = {"image": img, "text_input": cls_prompt}
with torch.no_grad():
image_features = feature_extractor.extract_features(
sample, mode="image"
).image_embeds_proj[:, 0]
text_features = feature_extractor.extract_features(
sample, mode="text"
).text_embeds_proj[:, 0]
st.write(image_features.shape)
st.write(text_features.shape)
sims = (image_features @ text_features.t())[0] / feature_extractor.temp
sims = torch.nn.Softmax(dim=0)(sims)
inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
elif model_type.startswith("CLIP"):
if model_type == "CLIP_ViT-B-32":
model = load_model_cache(model_type="CLIP_ViT-B-32", device=device)
elif model_type == "CLIP_ViT-B-16":
model = load_model_cache(model_type="CLIP_ViT-B-16", device=device)
elif model_type == "CLIP_ViT-L-14":
model = load_model_cache(model_type="CLIP_ViT-L-14", device=device)
else:
raise ValueError(f"Unknown model type {model_type}")
if score_type == "Cosine":
# image_preprocess = ClipImageEvalProcessor(image_size=336)
image_preprocess = ClipImageEvalProcessor(image_size=224)
img = image_preprocess(raw_img).unsqueeze(0).to(device)
sample = {"image": img, "text_input": cls_names}
with torch.no_grad():
clip_features = model.extract_features(sample)
image_features = clip_features.image_embeds_proj
text_features = clip_features.text_embeds_proj
sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)
inv_sims = sims.tolist()[::-1]
else:
st.warning("CLIP does not support multimodal scoring.")
return
fig = go.Figure(
go.Bar(
x=inv_sims,
y=cls_names[::-1],
text=["{:.2f}".format(s) for s in inv_sims],
orientation="h",
)
)
fig.update_traces(
textfont_size=12,
textangle=0,
textposition="outside",
cliponaxis=False,
)
col2.plotly_chart(fig, use_container_width=True)
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import random
from collections import OrderedDict
from functools import reduce
import streamlit as st
from lavis.common.registry import registry
from lavis.datasets.builders import dataset_zoo, load_dataset
from lavis.datasets.builders.base_dataset_builder import load_dataset_config
from PIL import Image
IMAGE_LAYOUT = 3, 4
VIDEO_LAYOUT = 1, 2
PREV_STR = "Prev"
NEXT_STR = "Next"
def sample_dataset(dataset, indices):
samples = [dataset.displ_item(idx) for idx in indices]
return samples
def get_concat_v(im1, im2):
margin = 5
canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height))
canvas = Image.new("RGB", canvas_size, "White")
canvas.paste(im1, (0, 0))
canvas.paste(im2, (im1.width + margin, 0))
return canvas
def resize_img_w(raw_img, new_w=224):
if isinstance(raw_img, list):
resized_imgs = [resize_img_w(img, 196) for img in raw_img]
# concatenate images
resized_image = reduce(get_concat_v, resized_imgs)
else:
w, h = raw_img.size
scaling_factor = new_w / w
resized_image = raw_img.resize(
(int(w * scaling_factor), int(h * scaling_factor))
)
return resized_image
def get_visual_key(dataset):
if "image" in dataset[0]:
return "image"
elif "image0" in dataset[0]: # NLVR2 dataset
return "image"
elif "video" in dataset[0]:
return "video"
else:
raise ValueError("Visual key not found.")
def gather_items(samples, exclude=[]):
gathered = []
for s in samples:
ns = OrderedDict()
for k in s.keys():
if k not in exclude:
ns[k] = s[k]
gathered.append(ns)
return gathered
@st.cache(allow_output_mutation=True)
def load_dataset_cache(name):
return load_dataset(name)
def format_text(text):
md = "\n\n".join([f"**{k}**: {v}" for k, v in text.items()])
return md
def show_samples(dataset, offset=0, is_next=False):
visual_key = get_visual_key(dataset)
num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT
n_samples = num_rows * num_cols
if not shuffle:
if is_next:
start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples)
else:
start = max(0, int(start_idx) + offset - n_samples)
st.session_state.last_start = start
end = min(start + n_samples, len(dataset))
indices = list(range(start, end))
else:
indices = random.sample(range(len(dataset)), n_samples)
samples = sample_dataset(dataset, indices)
visual_info = (
iter([resize_img_w(s[visual_key]) for s in samples])
if visual_key == "image"
# else iter([s[visual_key] for s in samples])
else iter([s["file"] for s in samples])
)
text_info = gather_items(samples, exclude=["image", "video"])
text_info = iter([format_text(s) for s in text_info])
st.markdown(
"""<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
unsafe_allow_html=True,
)
for _ in range(num_rows):
with st.container():
for col in st.columns(num_cols):
# col.text(next(text_info))
# col.caption(next(text_info))
try:
col.markdown(next(text_info))
if visual_key == "image":
col.image(next(visual_info), use_column_width=True, clamp=True)
elif visual_key == "video":
col.markdown(
"![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)"
)
except StopIteration:
break
st.markdown(
"""<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
unsafe_allow_html=True,
)
st.session_state.n_display = n_samples
if __name__ == "__main__":
st.set_page_config(
page_title="LAVIS Dataset Explorer",
# layout="wide",
initial_sidebar_state="expanded",
)
dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())
function = st.sidebar.selectbox("Function:", ["Browser"], index=0)
if function == "Browser":
shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)
dataset = load_dataset_cache(dataset_name)
split = st.sidebar.selectbox("Split:", dataset.keys())
dataset_len = len(dataset[split])
st.success(
f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}"
)
if "last_dataset" not in st.session_state:
st.session_state.last_dataset = dataset_name
st.session_state.last_split = split
if "last_start" not in st.session_state:
st.session_state.last_start = 0
if "start_idx" not in st.session_state:
st.session_state.start_idx = 0
if "shuffle" not in st.session_state:
st.session_state.shuffle = shuffle
if "first_run" not in st.session_state:
st.session_state.first_run = True
elif (
st.session_state.last_dataset != dataset_name
or st.session_state.last_split != split
):
st.session_state.first_run = True
st.session_state.last_dataset = dataset_name
st.session_state.last_split = split
elif st.session_state.shuffle != shuffle:
st.session_state.shuffle = shuffle
st.session_state.first_run = True
if not shuffle:
n_col, p_col = st.columns([0.05, 1])
prev_button = n_col.button(PREV_STR)
next_button = p_col.button(NEXT_STR)
else:
next_button = st.button(NEXT_STR)
if not shuffle:
start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0)
if not start_idx.isdigit():
st.error(f"Input to 'Begin from' must be digits, found {start_idx}.")
else:
if int(start_idx) != st.session_state.start_idx:
st.session_state.start_idx = int(start_idx)
st.session_state.last_start = int(start_idx)
if prev_button:
show_samples(
dataset[split],
offset=st.session_state.last_start - st.session_state.start_idx,
is_next=False,
)
if next_button:
show_samples(
dataset[split],
offset=st.session_state.last_start - st.session_state.start_idx,
is_next=True,
)
if st.session_state.first_run:
st.session_state.first_run = False
show_samples(
dataset[split],
offset=st.session_state.last_start - st.session_state.start_idx,
is_next=True,
)
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import numpy as np
import streamlit as st
import torch
from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
from lavis.processors import load_processor
from PIL import Image
from app import device, load_demo_image
from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
def app():
model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
if model_type.startswith("BLIP"):
blip_type = model_type.split("_")[1]
model = load_blip_itm_model(device, model_type=blip_type)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
st.markdown(
"<h1 style='text-align: center;'>Image Text Matching</h1>",
unsafe_allow_html=True,
)
values = list(range(1, 12))
default_layer_num = values.index(7)
layer_num = (
st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
)
instructions = """Try the provided image or upload your own:"""
file = st.file_uploader(instructions)
col1, col2 = st.columns(2)
col1.header("Image")
col2.header("GradCam")
if file:
raw_img = Image.open(file).convert("RGB")
else:
raw_img = load_demo_image()
w, h = raw_img.size
scaling_factor = 720 / w
resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
col1.image(resized_image, use_column_width=True)
col3, col4 = st.columns(2)
col3.header("Text")
user_question = col3.text_input(
"Input your sentence!", "a woman sitting on the beach with a dog"
)
submit_button = col3.button("Submit")
col4.header("Matching score")
if submit_button:
tokenizer = init_bert_tokenizer()
img = vis_processor(raw_img).unsqueeze(0).to(device)
text_processor = load_processor("blip_caption").build()
qry = text_processor(user_question)
norm_img = np.float32(resized_image) / 255
qry_tok = tokenizer(qry, return_tensors="pt").to(device)
gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
col2.image(avg_gradcam, use_column_width=True, clamp=True)
# output = model(img, question)
itm_score = torch.nn.functional.softmax(output, dim=1)
new_title = (
'<p style="text-align: left; font-size: 25px;">\n{:.3f}%</p>'.format(
itm_score[0][1].item() * 100
)
)
col4.markdown(new_title, unsafe_allow_html=True)
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from app.multipage import MultiPage
from app import vqa, caption
from app import image_text_match as itm
from app import text_localization as tl
from app import multimodal_search as ms
from app import classification as cl
if __name__ == "__main__":
app = MultiPage()
app.add_page("Image Description Generation", caption.app)
app.add_page("Multimodal Search", ms.app)
app.add_page("Visual Question Answering", vqa.app)
app.add_page("Image Text Matching", itm.app)
app.add_page("Text Localization", tl.app)
app.add_page("Classification", cl.app)
app.run()
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import os
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from app import cache_root, device
from app.utils import (
getAttMap,
init_bert_tokenizer,
load_blip_itm_model,
read_img,
resize_img,
)
from lavis.models import load_model
from lavis.processors import load_processor
@st.cache(
hash_funcs={
torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
.cpu()
.numpy()
},
allow_output_mutation=True,
)
def load_feat():
from lavis.common.utils import download_url
dirname = os.path.join(os.path.dirname(__file__), "assets")
filename = "path2feat_coco_train2014.pth"
filepath = os.path.join(dirname, filename)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"
if not os.path.exists(filepath):
download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")
path2feat = torch.load(filepath)
paths = sorted(path2feat.keys())
all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)
return path2feat, paths, all_img_feats
@st.cache(
hash_funcs={
torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
.cpu()
.numpy()
},
allow_output_mutation=True,
)
def load_feature_extractor_model(device):
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
model = load_model(
"blip_feature_extractor", model_type="base", is_eval=True, device=device
)
model.load_from_pretrained(model_url)
return model
def app():
# === layout ===
model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
file_root = os.path.join(cache_root, "coco/images/train2014/")
values = [12, 24, 48]
default_layer_num = values.index(24)
num_display = st.sidebar.selectbox(
"Number of images:", values, index=default_layer_num
)
show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1)
itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0)
# st.title('Multimodal Search')
st.markdown(
"<h1 style='text-align: center;'>Multimodal Search</h1>", unsafe_allow_html=True
)
# === event ===
vis_processor = load_processor("blip_image_eval").build(image_size=384)
text_processor = load_processor("blip_caption")
user_question = st.text_input(
"Search query", "A dog running on the grass.", help="Type something to search."
)
user_question = text_processor(user_question)
feature_extractor = load_feature_extractor_model(device)
# ======= ITC =========
sample = {"text_input": user_question}
with torch.no_grad():
text_feature = feature_extractor.extract_features(
sample, mode="text"
).text_embeds_proj[0, 0]
path2feat, paths, all_img_feats = load_feat()
all_img_feats.to(device)
all_img_feats = F.normalize(all_img_feats, dim=1)
num_cols = 4
num_rows = int(num_display / num_cols)
similarities = text_feature @ all_img_feats.T
indices = torch.argsort(similarities, descending=True)[:num_display]
top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
sorted_similarities = [similarities[idx] for idx in indices]
filenames = [os.path.join(file_root, p) for p in top_paths]
# ========= ITM and GradCam ==========
bsz = 4 # max number of images to avoid cuda oom
if model_type.startswith("BLIP"):
blip_type = model_type.split("_")[1]
itm_model = load_blip_itm_model(device, model_type=blip_type)
tokenizer = init_bert_tokenizer()
queries_batch = [user_question] * bsz
queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)
num_batches = int(num_display / bsz)
avg_gradcams = []
all_raw_images = []
itm_scores = []
for i in range(num_batches):
filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
gradcam, itm_output = compute_gradcam_batch(
itm_model, images, queries_batch, queries_tok_batch
)
all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
for norm_img, grad_cam in zip(norm_imgs, gradcam):
avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
avg_gradcams.append(avg_gradcam)
with torch.no_grad():
itm_score = torch.nn.functional.softmax(itm_output, dim=1)
itm_scores.append(itm_score)
# ========= ITM re-ranking =========
itm_scores = torch.cat(itm_scores)[:, 1]
if itm_ranking:
itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)
avg_gradcams_sorted = []
all_raw_images_sorted = []
for idx in indices:
avg_gradcams_sorted.append(avg_gradcams[idx])
all_raw_images_sorted.append(all_raw_images[idx])
avg_gradcams = avg_gradcams_sorted
all_raw_images = all_raw_images_sorted
if show_gradcam:
images_to_show = iter(avg_gradcams)
else:
images_to_show = iter(all_raw_images)
for _ in range(num_rows):
with st.container():
for col in st.columns(num_cols):
col.image(next(images_to_show), use_column_width=True, clamp=True)
def read_and_process_images(image_paths, vis_processor):
raw_images = [read_img(path) for path in image_paths]
images = [vis_processor(r_img) for r_img in raw_images]
images_tensors = torch.stack(images).to(device)
return raw_images, images_tensors
def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):
model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.save_attention = True
output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
loss = output[:, 1].sum()
model.zero_grad()
loss.backward()
with torch.no_grad():
mask = tokenized_text.attention_mask.view(
tokenized_text.attention_mask.size(0), 1, -1, 1, 1
) # (bsz,1,token_len, 1,1)
token_length = mask.sum() - 2
token_length = token_length.cpu()
# grads and cams [bsz, num_head, seq_len, image_patch]
grads = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attn_gradients()
cams = model.text_encoder.base_model.base_model.encoder.layer[
block_num
].crossattention.self.get_attention_map()
# assume using vit large with 576 num image patch
cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
grads = (
grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
* mask
)
gradcam = cams * grads
# [enc token gradcam, average gradcam across token, gradcam for individual token]
# gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
gradcam = gradcam.mean(1).cpu().detach()
gradcam = (
gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length
)
return gradcam, output
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
"""
This file is the framework for generating multiple Streamlit applications
through an object oriented framework.
"""
# Import necessary libraries
import streamlit as st
# Define the multipage class to manage the multiple apps in our program
class MultiPage:
"""Framework for combining multiple streamlit applications."""
def __init__(self) -> None:
"""Constructor class to generate a list which will store all our applications as an instance variable."""
self.pages = []
def add_page(self, title, func) -> None:
"""Class Method to Add pages to the project
Args:
title ([str]): The title of page which we are adding to the list of apps
func: Python function to render this page in Streamlit
"""
self.pages.append({"title": title, "function": func})
def run(self):
# Dropdown to select the page to run
page = st.sidebar.selectbox(
"Navigation", self.pages, format_func=lambda page: page["title"]
)
# run the app function
page["function"]()
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import math
import numpy as np
import streamlit as st
from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
from lavis.processors import load_processor
from PIL import Image
from app import device, load_demo_image
from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
def app():
model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
values = list(range(1, 12))
default_layer_num = values.index(7)
layer_num = (
st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
)
st.markdown(
"<h1 style='text-align: center;'>Text Localization</h1>", unsafe_allow_html=True
)
vis_processor = load_processor("blip_image_eval").build(image_size=384)
text_processor = load_processor("blip_caption")
tokenizer = init_bert_tokenizer()
instructions = "Try the provided image and text or use your own ones."
file = st.file_uploader(instructions)
query = st.text_input(
"Try a different input.", "A girl playing with her dog on the beach."
)
submit_button = st.button("Submit")
col1, col2 = st.columns(2)
if file:
raw_img = Image.open(file).convert("RGB")
else:
raw_img = load_demo_image()
col1.header("Image")
w, h = raw_img.size
scaling_factor = 720 / w
resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
col1.image(resized_image, use_column_width=True)
col2.header("GradCam")
if submit_button:
if model_type.startswith("BLIP"):
blip_type = model_type.split("_")[1]
model = load_blip_itm_model(device, model_type=blip_type)
img = vis_processor(raw_img).unsqueeze(0).to(device)
qry = text_processor(query)
qry_tok = tokenizer(qry, return_tensors="pt").to(device)
norm_img = np.float32(resized_image) / 255
gradcam, _ = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
col2.image(avg_gradcam, use_column_width=True, clamp=True)
num_cols = 4.0
num_tokens = len(qry_tok.input_ids[0]) - 2
num_rows = int(math.ceil(num_tokens / num_cols))
gradcam_iter = iter(gradcam[0][2:-1])
token_id_iter = iter(qry_tok.input_ids[0][1:-1])
for _ in range(num_rows):
with st.container():
for col in st.columns(int(num_cols)):
token_id = next(token_id_iter, None)
if not token_id:
break
gradcam_img = next(gradcam_iter)
word = tokenizer.decode([token_id])
gradcam_todraw = getAttMap(norm_img, gradcam_img, blur=True)
new_title = (
'<p style="text-align: center; font-size: 25px;">{}</p>'.format(
word
)
)
col.markdown(new_title, unsafe_allow_html=True)
# st.image(image, channels="BGR")
col.image(gradcam_todraw, use_column_width=True, clamp=True)
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import numpy as np
import streamlit as st
import torch
from lavis.models import BlipBase, load_model
from matplotlib import pyplot as plt
from PIL import Image
from scipy.ndimage import filters
from skimage import transform as skimage_transform
def resize_img(raw_img):
w, h = raw_img.size
scaling_factor = 240 / w
resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
return resized_image
def read_img(filepath):
raw_image = Image.open(filepath).convert("RGB")
return raw_image
@st.cache(
hash_funcs={
torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
.cpu()
.numpy()
},
allow_output_mutation=True,
)
def load_model_cache(name, model_type, is_eval, device):
return load_model(name, model_type, is_eval, device)
@st.cache(allow_output_mutation=True)
def init_bert_tokenizer():
tokenizer = BlipBase.init_tokenizer()
return tokenizer
def getAttMap(img, attMap, blur=True, overlap=True):
attMap -= attMap.min()
if attMap.max() > 0:
attMap /= attMap.max()
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
if blur:
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
attMap -= attMap.min()
attMap /= attMap.max()
cmap = plt.get_cmap("jet")
attMapV = cmap(attMap)
attMapV = np.delete(attMapV, 3, 2)
if overlap:
attMap = (
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
)
return attMap
@st.cache(
hash_funcs={
torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
.cpu()
.numpy()
},
allow_output_mutation=True,
)
def load_blip_itm_model(device, model_type="base"):
model = load_model(
"blip_image_text_matching", model_type, is_eval=True, device=device
)
return model