"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "bee8827a98fbb77def8cbe5c7aa4de956667baec"
Commit 34e4011b authored by zk's avatar zk
Browse files

首次提交

parents
Pipeline #3503 failed with stages
in 0 seconds
# IDE
.idea/
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# vscode
.vscode/
output/
outputs/
subs/
logs/
grounding/config/configs
grounding/version.py
vis/
tmp/
# Ignore all weight and model files
*.pth
*.pt
*.onnx
*.bin
*.mxr
*.db
*.run
*.csv
*.pftrace
xiongke_log.txt
migraphx_log.txt
weights/
checkpoints/
# Ignore environment installer packages
*.whl
*.tar.gz
*.zip
# Ignore caches and temporary files
__pycache__/
*.pyc
.ipynb_checkpoints/
build/
dist/
# Ignore your local images or datasets (adjust to your actual folder names)
data/
images/
*.jpg
*.png
# Ignore IDE configuration files
.vscode/
.idea/
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
ARG DEBIAN_FRONTEND=noninteractive
ENV CUDA_HOME=/usr/local/cuda \
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
SETUPTOOLS_USE_DISTUTILS=stdlib
RUN conda update conda -y
# Install libraries in the brand new image.
RUN apt-get -y update && apt-get install -y --no-install-recommends \
wget \
build-essential \
git \
python3-opencv \
ca-certificates && \
rm -rf /var/lib/apt/lists/*
# Set the working directory for all the subsequent Dockerfile instructions.
WORKDIR /opt/program
RUN git clone https://github.com/IDEA-Research/GroundingDINO.git
RUN mkdir weights ; cd weights ; wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth ; cd ..
RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y
ENV CUDA_HOME=$CONDA_PREFIX
ENV PATH=/usr/local/cuda/bin:$PATH
RUN cd GroundingDINO/ && python -m pip install .
COPY docker_test.py docker_test.py
CMD [ "python", "docker_test.py" ]
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023 - present, IDEA Research.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
<div align="center">
<img src="./.asset/grounding_dino_logo.png" width="30%">
</div>
# :sauropod: Grounding DINO
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
**[IDEA-CVR, IDEA-Research](https://github.com/IDEA-Research)**
[Shilong Liu](http://www.lsl.zone/), [Zhaoyang Zeng](https://scholar.google.com/citations?user=U_cvvUwAAAAJ&hl=zh-CN&oi=ao), [Tianhe Ren](https://rentainhe.github.io/), [Feng Li](https://scholar.google.com/citations?user=ybRe9GcAAAAJ&hl=zh-CN), [Hao Zhang](https://scholar.google.com/citations?user=B8hPxMQAAAAJ&hl=zh-CN), [Jie Yang](https://github.com/yangjie-cv), [Chunyuan Li](https://scholar.google.com/citations?user=Zd7WmXUAAAAJ&hl=zh-CN&oi=ao), [Jianwei Yang](https://jwyang.github.io/), [Hang Su](https://scholar.google.com/citations?hl=en&user=dxN1_X0AAAAJ&view_op=list_works&sortby=pubdate), [Jun Zhu](https://scholar.google.com/citations?hl=en&user=axsP38wAAAAJ), [Lei Zhang](https://www.leizhang.org/)<sup>:email:</sup>.
[[`Paper`](https://arxiv.org/abs/2303.05499)] [[`Demo`](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)] [[`BibTex`](#black_nib-citation)]
PyTorch implementation and pretrained models for Grounding DINO. For details, see the paper **[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)**.
- 🔥 **[Grounded SAM 2](https://github.com/IDEA-Research/Grounded-SAM-2)** is released now, which combines Grounding DINO with [SAM 2](https://github.com/facebookresearch/segment-anything-2) for any object tracking in open-world scenarios.
- 🔥 **[Grounding DINO 1.5](https://github.com/IDEA-Research/Grounding-DINO-1.5-API)** is released now, which is IDEA Research's **Most Capable** Open-World Object Detection Model!
- 🔥 **[Grounding DINO](https://arxiv.org/abs/2303.05499)** and **[Grounded SAM](https://arxiv.org/abs/2401.14159)** are now supported in Huggingface. For more convenient use, you can refer to [this documentation](https://huggingface.co/docs/transformers/model_doc/grounding-dino)
## :sun_with_face: Helpful Tutorial
- :grapes: [[Read our arXiv Paper](https://arxiv.org/abs/2303.05499)]
- :apple: [[Watch our simple introduction video on YouTube](https://youtu.be/wxWDt5UiwY8)]
- :blossom: &nbsp;[[Try the Colab Demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)]
- :sunflower: [[Try our Official Huggingface Demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)]
- :maple_leaf: [[Watch the Step by Step Tutorial about GroundingDINO by Roboflow AI](https://youtu.be/cMa77r3YrDk)]
- :mushroom: [[GroundingDINO: Automated Dataset Annotation and Evaluation by Roboflow AI](https://youtu.be/C4NqaRBz_Kw)]
- :hibiscus: [[Accelerate Image Annotation with SAM and GroundingDINO by Roboflow AI](https://youtu.be/oEQYStnF2l8)]
- :white_flower: [[Autodistill: Train YOLOv8 with ZERO Annotations based on Grounding-DINO and Grounded-SAM by Roboflow AI](https://github.com/autodistill/autodistill)]
<!-- Grounding DINO Methods |
[![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8) -->
<!-- Grounding DINO Demos |
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) -->
<!-- [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk)
[![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/oEQYStnF2l8)
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/C4NqaRBz_Kw) -->
## :sparkles: Highlight Projects
- [Semantic-SAM: a universal image segmentation model to enable segmenting and recognizing anything at any desired granularity](https://github.com/UX-Decoder/Semantic-SAM)
- [DetGPT: Detect What You Need via Reasoning](https://github.com/OptimalScale/DetGPT)
- [Grounded-SAM: Marrying Grounding DINO with Segment Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)
- [Grounding DINO with Stable Diffusion](demo/image_editing_with_groundingdino_stablediffusion.ipynb)
- [Grounding DINO with GLIGEN for Controllable Image Editing](demo/image_editing_with_groundingdino_gligen.ipynb)
- [OpenSeeD: A Simple and Strong Openset Segmentation Model](https://github.com/IDEA-Research/OpenSeeD)
- [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
- [X-GPT: Conversational Visual Agent supported by X-Decoder](https://github.com/microsoft/X-Decoder/tree/xgpt)
- [GLIGEN: Open-Set Grounded Text-to-Image Generation](https://github.com/gligen/GLIGEN)
- [LLaVA: Large Language and Vision Assistant](https://github.com/haotian-liu/LLaVA)
<!-- Extensions | [Grounding DINO with Segment Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything); [Grounding DINO with Stable Diffusion](demo/image_editing_with_groundingdino_stablediffusion.ipynb); [Grounding DINO with GLIGEN](demo/image_editing_with_groundingdino_gligen.ipynb) -->
<!-- Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now! -->
## :bulb: Highlight
- **Open-Set Detection.** Detect **everything** with language!
- **High Performance.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
- **Flexible.** Collaboration with Stable Diffusion for Image Editing.
## :fire: News
- **`2023/07/18`**: We release [Semantic-SAM](https://github.com/UX-Decoder/Semantic-SAM), a universal image segmentation model that can segment and recognize anything at any desired granularity. **Code** and **checkpoint** are available!
- **`2023/06/17`**: We provide an example to evaluate Grounding DINO on COCO zero-shot performance.
- **`2023/04/15`**: Refer to [CV in the Wild Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set recognition!
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editing.
- **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editing.
- **`2023/04/06`**: We build a new demo by marrying GroundingDINO with [Segment-Anything](https://github.com/facebookresearch/segment-anything), named **[Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)**, which aims to support segmentation in GroundingDINO.
- **`2023/03/28`**: A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)]
- **`2023/03/28`**: Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space!
- **`2023/03/27`**: Support CPU-only mode. Now the model can run on machines without GPUs.
- **`2023/03/25`**: A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)]
- **`2023/03/22`**: Code is available now!
<details open>
<summary><font size="4">
Description
</font></summary>
<a href="https://arxiv.org/abs/2303.05499">Paper</a> introduction.
<img src=".asset/hero_figure.png" alt="ODinW" width="100%">
Marrying <a href="https://github.com/IDEA-Research/GroundingDINO">Grounding DINO</a> and <a href="https://github.com/gligen/GLIGEN">GLIGEN</a>
<img src="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GD_GLIGEN.png" alt="gd_gligen" width="100%">
</details>
## :star: Explanations/Tips for Grounding DINO Inputs and Outputs
- Grounding DINO accepts an `(image, text)` pair as inputs.
- It outputs `900` (by default) object boxes. Each box has similarity scores across all input words (as shown in the figures below).
- By default, we keep the boxes whose highest similarity is above the `box_threshold` (see the short sketch after the figures).
- We extract the words whose similarities are higher than the `text_threshold` as predicted labels.
- If you want to obtain objects for specific phrases, such as `dogs` in the sentence `two dogs with a stick.`, you can select the boxes with the highest text similarities to `dogs` as the final outputs.
- Note that each word can be split into **more than one** token by different tokenizers, so the number of words in a sentence may not equal the number of text tokens.
- We suggest separating different category names with `.` for Grounding DINO.
![model_explain1](.asset/model_explan1.PNG)
![model_explain2](.asset/model_explan2.PNG)
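To make the thresholding above concrete, here is a minimal NumPy sketch of the selection logic (an illustration only; the array shapes and variable names are assumptions, not the library's API):

```python
import numpy as np

# Suppose `logits` holds per-box similarity scores over the prompt's text
# tokens, one row per predicted box (shape: [900, num_tokens]).
logits = np.random.rand(900, 8)
box_threshold, text_threshold = 0.35, 0.25

# Keep boxes whose best token similarity exceeds box_threshold ...
keep = logits.max(axis=1) > box_threshold
kept = logits[keep]

# ... and, for each kept box, read off the token positions above
# text_threshold as its predicted phrase.
for scores in kept:
    matched_positions = np.nonzero(scores > text_threshold)[0]
    print("box matches token positions:", matched_positions)
```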
## :label: TODO
- [x] Release inference code and demo.
- [x] Release checkpoints.
- [x] Grounding DINO with Stable Diffusion and GLIGEN demos.
- [ ] Release training codes.
## :hammer_and_wrench: Install
**Note:**
0. If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. The package will be compiled in CPU-only mode if CUDA is not available.
Please make sure to follow the installation steps strictly; otherwise the program may produce:
```bash
NameError: name '_C' is not defined
```
If this happens, please reinstall GroundingDINO by re-cloning the repository and running all the installation steps again.
#### How to check CUDA:
```bash
echo $CUDA_HOME
```
If it prints nothing, the path has not been set yet. Run the following so the environment variable is set in the current shell:
```bash
export CUDA_HOME=/path/to/cuda-11.3
```
Note that the CUDA version should match your CUDA runtime, since multiple CUDA versions may coexist on the same machine.
If you want to set the CUDA_HOME permanently, store it using:
```bash
echo 'export CUDA_HOME=/path/to/cuda' >> ~/.bashrc
```
After that, source the bashrc file and check CUDA_HOME:
```bash
source ~/.bashrc
echo $CUDA_HOME
```
In this example, /path/to/cuda-11.3 should be replaced with the path where your CUDA toolkit is installed. You can find this by typing **which nvcc** in your terminal.
For instance, if the output is /usr/local/cuda/bin/nvcc, then:
```bash
export CUDA_HOME=/usr/local/cuda
```
**Installation:**
1. Clone the GroundingDINO repository from GitHub.
```bash
git clone https://github.com/IDEA-Research/GroundingDINO.git
```
2. Change the current directory to the GroundingDINO folder.
```bash
cd GroundingDINO/
```
3. Install the required dependencies in the current directory.
```bash
pip install -e .
```
4. Download pre-trained model weights.
```bash
mkdir weights
cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
```
## :arrow_forward: Demo
Check your GPU ID (only if you're using a GPU)
```bash
nvidia-smi
```
Replace `{GPU ID}`, `image_you_want_to_detect.jpg`, and `"dir you want to save the output"` with appropriate values in the following command
```bash
CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
-p weights/groundingdino_swint_ogc.pth \
-i image_you_want_to_detect.jpg \
-o "dir you want to save the output" \
-t "chair"
[--cpu-only] # add this flag for CPU-only mode
```
If you would like to specify the phrases to detect, here is a demo:
```bash
CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
-p ./groundingdino_swint_ogc.pth \
-i .asset/cat_dog.jpeg \
-o logs/1111 \
-t "There is a cat and a dog in the image ." \
--token_spans "[[[9, 10], [11, 14]], [[19, 20], [21, 24]]]"
[--cpu-only] # add this flag for CPU-only mode
```
The token_spans specify the start and end character positions of phrases. For example, the first phrase is `[[9, 10], [11, 14]]`: `"There is a cat and a dog in the image ."[9:10] = 'a'` and `"There is a cat and a dog in the image ."[11:14] = 'cat'`, so it refers to the phrase `a cat`. Similarly, `[[19, 20], [21, 24]]` refers to the phrase `a dog`.
See `demo/inference_on_a_image.py` for more details.
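As a quick sanity check of this span convention, the snippet below (an illustrative sketch, not part of the repository) reconstructs the phrases directly from the character spans:

```python
caption = "There is a cat and a dog in the image ."

# Each phrase is a list of [start, end) character spans into the caption.
token_spans = [[[9, 10], [11, 14]], [[19, 20], [21, 24]]]
for spans in token_spans:
    phrase = " ".join(caption[start:end] for start, end in spans)
    print(phrase)  # prints "a cat", then "a dog"
```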
**Running with Python:**
```python
from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weights/groundingdino_swint_ogc.pth")
IMAGE_PATH = "weights/dog-3.jpeg"
TEXT_PROMPT = "chair . person . dog ."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
image_source, image = load_image(IMAGE_PATH)
boxes, logits, phrases = predict(
model=model,
image=image,
caption=TEXT_PROMPT,
box_threshold=BOX_TRESHOLD,
text_threshold=TEXT_TRESHOLD
)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
cv2.imwrite("annotated_image.jpg", annotated_frame)
```
**Web UI**
We also provide demo code to integrate Grounding DINO with a Gradio Web UI. See the file `demo/gradio_app.py` for more details.
**Notebooks**
- We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editing.
- We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editing.
## COCO Zero-shot Evaluations
We provide an example to evaluate Grounding DINO zero-shot performance on COCO. The results should be **48.5**.
```bash
CUDA_VISIBLE_DEVICES=0 \
python demo/test_ap_on_coco.py \
-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
-p weights/groundingdino_swint_ogc.pth \
--anno_path /path/to/annotations/instances_val2017.json \
--image_dir /path/to/val2017
```
## :luggage: Checkpoints
<!-- insert a table -->
<table>
<thead>
<tr style="text-align: right;">
<th></th>
<th>name</th>
<th>backbone</th>
<th>Data</th>
<th>box AP on COCO</th>
<th>Checkpoint</th>
<th>Config</th>
</tr>
</thead>
<tbody>
<tr>
<th>1</th>
<td>GroundingDINO-T</td>
<td>Swin-T</td>
<td>O365,GoldG,Cap4M</td>
<td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">GitHub link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth">HF link</a></td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
</tr>
<tr>
<th>2</th>
<td>GroundingDINO-B</td>
<td>Swin-B</td>
<td>COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO</td>
<td>56.7 </td>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth">GitHub link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth">HF link</a>
<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinB_cfg.py">link</a></td>
</tr>
</tbody>
</table>
## :medal_military: Results
<details open>
<summary><font size="4">
COCO Object Detection Results
</font></summary>
<img src=".asset/COCO.png" alt="COCO" width="100%">
</details>
<details open>
<summary><font size="4">
ODinW Object Detection Results
</font></summary>
<img src=".asset/ODinW.png" alt="ODinW" width="100%">
</details>
<details open>
<summary><font size="4">
Marrying Grounding DINO with <a href="https://github.com/Stability-AI/StableDiffusion">Stable Diffusion</a> for Image Editing
</font></summary>
See our example <a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/image_editing_with_groundingdino_stablediffusion.ipynb">notebook</a> for more details.
<img src=".asset/GD_SD.png" alt="GD_SD" width="100%">
</details>
<details open>
<summary><font size="4">
Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing.
</font></summary>
See our example <a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/image_editing_with_groundingdino_gligen.ipynb">notebook</a> for more details.
<img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
</details>
## :sauropod: Model: Grounding DINO
It includes a text backbone, an image backbone, a feature enhancer, a language-guided query selection module, and a cross-modality decoder.
![arch](.asset/arch.png)
## :hearts: Acknowledgement
Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work is available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
## :black_nib: Citation
If you find our work helpful for your research, please consider citing the following BibTeX entry.
```bibtex
@article{liu2023grounding,
title={Grounding dino: Marrying dino with grounded pre-training for open-set object detection},
author={Liu, Shilong and Zeng, Zhaoyang and Ren, Tianhe and Li, Feng and Zhang, Hao and Yang, Jie and Li, Chunyuan and Yang, Jianwei and Su, Hang and Zhu, Jun and others},
journal={arXiv preprint arXiv:2303.05499},
year={2023}
}
```
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.0.dev0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
import collections
import re
import unicodedata
import six
from functools import lru_cache
import os
import numpy as np
@lru_cache()
def default_vocab():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.txt")
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file=default_vocab(), do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.special_tokens = set(['[CLS]', '[SEP]', '[PAD]', '[MASK]', '[UNK]'])
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
def vocab_size(self):
return len(self.vocab)
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
tokens = []
# Convert token IDs to token strings
for token_id in token_ids:
token = self.inv_vocab.get(token_id, '[UNK]')
if skip_special_tokens and token in self.special_tokens:
continue
tokens.append(token)
# Join the token list into a single string
text = ' '.join(tokens)
# Clean up whitespace and control characters
if clean_up_tokenization_spaces:
text = self.clean_up_tokenization(text)
return text
def clean_up_tokenization(self, text):
# Clean up redundant whitespace and control characters
text = text.replace(' ##', '')  # remove BERT subword markers
text = ' '.join(text.split())  # collapse redundant whitespace
return text
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
### Use the tokenizer to convert text into input_ids, token_type_ids, and attention_mask arrays. token_type_ids tells the model which sentence each token belongs to: for a single sentence it is all zeros; when converting a batch, tokens from different sentences get different values.
###https://juejin.cn/post/7266344040760590376
###https://zhuanlan.zhihu.com/p/341994096
### padding controls sequence padding. It can be a boolean or a string: (1) True or "longest": pad to the longest sequence in the batch (a single sequence is not padded); "max_length": pad to the length given by max_length, or to the maximum length the model accepts if max_length is not set (so even a single sequence gets padded to that length). (2) False or "do_not_pad": do not pad; as noted above, this is the default behavior. (An illustrative usage example appears at the end of this file.)
def tokenize(_tokenizer, texts, specical_texts, context_length: int = 52):
if isinstance(texts, str):
texts = [texts]
all_tokens = []
for text in texts:
all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[
:context_length - 2] + [_tokenizer.vocab['[SEP]']])
input_ids = np.array(all_tokens, dtype=np.int64)
token_type_ids = np.zeros_like(input_ids)  ### all zeros for a single sentence
attention_mask = (input_ids > 0).astype(np.bool_)
specical_tokens = []
for text in specical_texts:
specical_tokens.append(_tokenizer.vocab[text])
return input_ids, token_type_ids, attention_mask, specical_tokens
def generate_masks_with_special_tokens_and_transfer_map(input_ids, special_tokens_list):
bs, num_token = input_ids.shape
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = np.zeros((bs, num_token)).astype(np.bool_)  ### newer numpy versions require np.bool_
for special_token in special_tokens_list:
special_tokens_mask |= input_ids == special_token
# idxs: each row is a list of indices of special tokens
idxs = np.array(np.nonzero(special_tokens_mask)).T
# generate attention mask and positional ids
text_self_attention_masks = np.tile(np.expand_dims(np.eye(num_token, dtype=np.bool_), axis=0), (bs, 1, 1))
position_ids = np.zeros((bs, num_token))
# cate_to_token_mask_list = [[] for _ in range(bs)]
previous_col = 0
for i in range(idxs.shape[0]):
row, col = idxs[i]
if (col == 0) or (col == num_token - 1):
text_self_attention_masks[row, col, col] = True
position_ids[row, col] = 0
else:
text_self_attention_masks[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
position_ids[row, previous_col + 1 : col + 1] = np.arange(0, col - previous_col)
# c2t_maski = np.zeros((num_token)).astype(np.bool_)
# c2t_maski[previous_col + 1 : col] = True
# cate_to_token_mask_list[row].append(c2t_maski)
previous_col = col
return text_self_attention_masks, position_ids.astype(np.int64)
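# Illustrative usage (a sketch added for clarity; not part of the original file):
# build the default FullTokenizer and produce the arrays described above.
if __name__ == "__main__":
    _tok = FullTokenizer()  # uses the vocab.txt next to this file
    ids, type_ids, attn, specials = tokenize(_tok, "a cat and a dog .", [".", "[SEP]"])
    masks, pos_ids = generate_masks_with_special_tokens_and_transfer_map(ids, specials)
    print(ids.shape, masks.shape, pos_ids.shape)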
import torch
import onnx
from onnxsim import simplify
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from torch.onnx import register_custom_op_symbolic
from groundingdino.models.GroundingDINO.ms_deform_attn import MultiScaleDeformableAttnFunction
def ms_deform_attn_symbolic(g, value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step):
return g.op("custom::ms_deform_attn",
value, spatial_shapes, level_start_index, sampling_locations, attention_weights,
im2col_step_i=im2col_step)
MultiScaleDeformableAttnFunction.symbolic = ms_deform_attn_symbolic
config_file = '../groundingdino/config/GroundingDINO_SwinB_cfg.py'
checkpoint_path = '../weights/groundingdino_swinb_cogcoor.pth'
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
args = SLConfig.fromfile(model_config_path)
args.device = "cuda" if not cpu_only else "cpu"
args.use_checkpoint = False
args.use_transformer_ckpt = False
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
return model
# Load the model
model = load_model(config_file, checkpoint_path, cpu_only=False)  # <-- must be CUDA, otherwise the custom op is not triggered
device = "cuda"
model = model.to(device)
caption = "car ."
input_ids = model.tokenizer([caption], return_tensors="pt")["input_ids"].to(device)
position_ids = torch.tensor([[0, 0, 1, 0]]).to(device)
token_type_ids = torch.tensor([[0, 0, 0, 0]]).to(device)
attention_mask = torch.tensor([[True, True, True, True]]).to(device)
text_token_mask = torch.tensor([[[True, False, False, False],
[False, True, True, False],
[False, True, True, False],
[False, False, False, True]]]).to(device)
img = torch.randn(1, 3, 800, 1200).to(device)
# Export to ONNX
onnx_output_path = "../weights/ground_deform.onnx"
torch.onnx.export(
model,
args=(img, input_ids, attention_mask, position_ids, token_type_ids, text_token_mask),
f=onnx_output_path,
input_names=["img", "input_ids", "attention_mask", "position_ids", "token_type_ids", "text_token_mask"],
output_names=["logits", "boxes"],
opset_version=17,
custom_opsets={"custom": 1},
export_custom_ops=True,
do_constant_folding=True,
)
print("✅ ONNX 导出成功!")
import onnxruntime as ort
from onnxruntime.transformers.optimizer import optimize_model
# 1. Path configuration (point to the pure-FP16 model you just generated)
model_path = "../weights/ground_deform_fp16_all.onnx"
out_path = "../weights/ground_deform_fused_final.onnx"
# [Key] make sure this path points to the freshly built FP16 .so file
custom_op_lib = "../ort_plugin_fp16/build/libms_deform_attn_ort.so"
print(f"🚀 准备黑入底层并注入自定义算子: {custom_op_lib}")
# =====================================================================
# 💀 Core trick: monkey patching
# Intercept the InferenceSession created inside the ORT optimizer, force-register the custom-op .so, and explicitly exclude MIGraphX
# =====================================================================
original_init = ort.InferenceSession.__init__
def patched_init(self, path_or_bytes, sess_options=None, providers=None, provider_options=None, **kwargs):
if sess_options is None:
sess_options = ort.SessionOptions()
# Force-register the custom op library
sess_options.register_custom_ops_library(custom_op_lib)
# [Critical]: allow only ROCm and CPU; drop MIGraphX to avoid crashes during offline shape inference!
providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
original_init(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)
ort.InferenceSession.__init__ = patched_init
print("✅ 拦截器注入成功,已精确屏蔽 MIGraphX 干扰,保留纯净 ROCm...")
# =====================================================================
try:
# 2. Run the official Transformer optimization engine
optimized_model = optimize_model(
input=model_path,
model_type='bert',  # match the BERT graph topology (Attention and LayerNorm patterns)
use_gpu=True  # keep GPU-oriented fusion of large operators
)
# 3. Save the optimized, fused graph
optimized_model.save_model_to_file(out_path)
print(f"\n🎉 大功告成!融合后的超级模型已保存至: {out_path}")
print("👉 现在把测速脚本里的模型改成这个,去测最终的极限 FPS 吧!")
except Exception as e:
print(f"\n❌ 优化失败: {e}")
import onnx
from onnx import helper
import numpy as np
model_path = "../weights/ground_deform_fp16_all.onnx"
final_path = "../weights/ground_deform_fused_final.onnx"
print("🚀 开始执行“外科手术级”手动算子融合...")
model = onnx.load(model_path)
graph = model.graph
# 1. Build a node index
nodes = list(graph.node)
node_map = {node.output[0]: node for node in nodes}
# 2. Prepare removal lists
nodes_to_remove = set()
new_nodes = []
# 3. Scan for the decomposed LayerNorm pattern
# Standard fragment sequence: ReduceMean -> Sub -> Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul -> Add
for node in nodes:
if node.op_type == "Add" and node.name.endswith("LayerNorm/add_1") or "layernorm" in node.name.lower():
# This is usually the final Add node of a LayerNorm
# Walk backwards from here to try to lock onto the whole fragment group
try:
# A more robust strategy is used here: look for specific fragment combinations,
# provided the node output is FP16 and it sits at the end of a Transformer block
pass
except:
continue
# -------------------------------------------------------------------------
# [Core logic]: call the official onnxruntime graph-fusion API directly (without validation)
# -------------------------------------------------------------------------
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.fusion_layernorm import FusionLayerNorm
# Wrap the model
onnx_model = OnnxModel(model)
fusion = FusionLayerNorm(onnx_model)
# Force LayerNorm fusion
# Note: global optimization is not enabled here; only the LayerNorm pattern is matched
print("⏳ Capturing and fusing the decomposed LayerNorm fragments...")
fusion.apply()
# 4. Repair custom-op inputs that may have been dropped by mistake
# The optimizer sometimes removes initializers that are not explicitly connected in the graph;
# missing weights would need to be restored from the original model
print("🩹 Checking custom-op data integrity...")
# (this logic is handled inside fusion.apply)
# 5. Save the model
onnx.save(onnx_model.model, final_path)
print(f"\n✅ 奇迹发生了!手动融合完成。")
print(f"👉 最终模型: {final_path}")
print("现在去跑推理脚本,看看那 1600 次的 LayerNorm 还在不在!")
from typing import Tuple, List, Dict
import cv2
import numpy as np
import torch
import onnxruntime as ort
from transformers import BertTokenizer, AutoTokenizer
import bisect
import time
import os
from groundingdino.util.inference import load_image
from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
so_options = ort.SessionOptions()
custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
so_options.register_custom_ops_library(custom_op_lib_path)
# Enable ORT graph optimizations
so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def get_phrases_from_posmap(
posmap: np.ndarray, tokenized: Dict, tokenizer: AutoTokenizer, left_idx: int = 0, right_idx: int = 255
):
assert isinstance(posmap, np.ndarray), "posmap must be np.ndarray"
if posmap.ndim == 1:
# Mask out positions outside the [left_idx + 1, right_idx) span
posmap[:left_idx + 1] = False
posmap[right_idx:] = False
# Get the indices of non-zero (True) entries
non_zero_idx = np.nonzero(posmap)[0]
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
return tokenizer.decode(token_ids)
else:
raise NotImplementedError("posmap must be 1-dim")
def preprocess_caption(caption: str) -> str:
result = caption.lower().strip()
if result.endswith("."):
return result
return result + "."
# Key optimization: take the tokenizer as a parameter so the caller loads it once
def predict(
ort_session,
tokenizer: AutoTokenizer,  # tokenizer preloaded by the caller
image: np.array,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cpu",
remove_combined: bool = False,
is_benchmark: bool = False,  # new: whether this is a benchmark run (controls log output)
save_npy: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
# 1. Text preprocessing
t0 = time.time()
caption = preprocess_caption(caption=caption)
if not is_benchmark:
print(f"Caption processing took {(time.time() - t0):.3f}s")
captions = [caption]
# 3. Encode the text
t0 = time.time()
# Avoid reloading the tokenizer on every call (previously a major performance cost)
tokenized = tokenizer(captions, padding="longest", return_tensors="pt").to(device)
specical_tokens = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
if not is_benchmark:
print(f"Word embedding took {(time.time() - t0):.3f}s")
# 4. Generate attention masks and position ids
t0 = time.time()
(
text_self_attention_masks,
position_ids,
cate_to_token_mask_list,
) = generate_masks_with_special_tokens_and_transfer_map(
tokenized, specical_tokens, tokenizer)
if not is_benchmark:
print(f"Generate attention masks took {(time.time() - t0):.3f}s")
# 5. Truncate overly long text
max_text_len = 256
if text_self_attention_masks.shape[1] > max_text_len:
text_self_attention_masks = text_self_attention_masks[
:, : max_text_len, : max_text_len]
position_ids = position_ids[:, : max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : max_text_len]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : max_text_len]
# 6. Run model inference
attention_mask = np.asarray(tokenized["attention_mask"]).astype(bool)
input_dict = {
"img": np.expand_dims(np.asarray(image), axis=0),
"input_ids": np.asarray(tokenized["input_ids"]),
"attention_mask": attention_mask,
"position_ids": np.asarray(position_ids),
"token_type_ids": np.asarray(tokenized["token_type_ids"]),
"text_token_mask": np.asarray(text_self_attention_masks)
}
# ===================== [Save model inputs as .npy] =====================
if save_npy and not is_benchmark:
save_dir = "npy_io"
os.makedirs(save_dir, exist_ok=True)
# Save all inputs
np.save(f"{save_dir}/input_img.npy", input_dict["img"])
np.save(f"{save_dir}/input_input_ids.npy", input_dict["input_ids"])
np.save(f"{save_dir}/input_attention_mask.npy", input_dict["attention_mask"])
np.save(f"{save_dir}/input_position_ids.npy", input_dict["position_ids"])
np.save(f"{save_dir}/input_token_type_ids.npy", input_dict["token_type_ids"])
np.save(f"{save_dir}/input_text_token_mask.npy", input_dict["text_token_mask"])
print(f"\n✅ 模型输入已保存到 {save_dir}/ 文件夹")
# ====================================================================
t0 = time.time()
outputs = ort_session.run(['logits', 'boxes'], input_dict)
infer_time = time.time() - t0
if not is_benchmark:
print(f"Inference time: {infer_time:.3f}s")
# ===================== Save model outputs as .npy files =====================
if save_npy and not is_benchmark:
np.save(f"{save_dir}/output_logits.npy", outputs[0])
np.save(f"{save_dir}/output_boxes.npy", outputs[1])
print(f"✅ Model outputs saved to {save_dir}/")
# ====================================================================
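# Hedged sketch: the dumped .npy files can be diffed against another backend's dump
# (e.g. an FP32 reference run) to spot numerical drift; "npy_io_fp32" below is a
# hypothetical directory name for such a reference dump.
#   ref = np.load("npy_io_fp32/output_logits.npy")
#   cur = np.load("npy_io/output_logits.npy")
#   print("max abs diff:", np.max(np.abs(ref - cur)), "allclose:", np.allclose(ref, cur, atol=1e-2))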
# 7. Extract predictions
prediction_logits = np.apply_along_axis(sigmoid, -1, outputs[0][0])
prediction_boxes = outputs[1][0]
if not is_benchmark:
print(f"\n=== Debug Info ===")
print(f"Prediction logits shape: {prediction_logits.shape}")
print(f"Prediction boxes shape: {prediction_boxes.shape}")
print(f"Max logit value: {np.max(prediction_logits):.4f}")
print(f"Mean logit value: {np.mean(prediction_logits):.4f}")
# 8. Filter boxes by score threshold
max_values = np.max(prediction_logits, axis=1)
mask = max_values > box_threshold
logits = prediction_logits[mask]
boxes = prediction_boxes[mask]
# 9. Re-tokenize the caption for phrase extraction
tokenized = tokenizer(caption)
# 10. Handle special tokens ([CLS]/[SEP]/'.') when splitting phrases
if remove_combined:
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()
insert_idx = bisect.bisect_left(sep_idx, max_idx)
right_idx = sep_idx[insert_idx]
left_idx = sep_idx[insert_idx - 1]
phrases.append(
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer, left_idx, right_idx).replace('.', '')
)
else:
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit in logits
]
return boxes, np.max(logits, axis=1), phrases
# New: full benchmark helper (warm-up + timed runs)
def benchmark_performance(
ort_session, tokenizer, image, caption, box_threshold, text_threshold,
warmup_runs=5, test_runs=10, device="cpu"
):
"""
性能测试函数:包含预热和实际推理
:param warmup_runs: 预热次数
:param test_runs: 实际测试次数
"""
print("="*60)
print("📊 开始性能测试(包含预热+实际推理)")
print("="*60)
# 1. 预热阶段
print(f"\n🔥 预热阶段({warmup_runs} 次)- 不计入性能统计")
warmup_start = time.time()
for i in range(warmup_runs):
t0 = time.time()
predict(ort_session, tokenizer, image, caption, box_threshold, text_threshold, device, is_benchmark=True)
warmup_time = time.time() - t0
print(f"预热 {i+1}/{warmup_runs} - 耗时: {warmup_time*1000:.2f} ms")
total_warmup_time = time.time() - warmup_start
print(f"\n预热完成 - 总耗时: {total_warmup_time:.3f} s, 平均每次: {total_warmup_time/warmup_runs*1000:.2f} ms")
# 2. 实际推理测试阶段
print(f"\n🚀 实际推理测试阶段({test_runs} 次)- 统计性能指标")
test_start = time.time()
infer_times = [] # 记录每次推理耗时
for i in range(test_runs):
t0 = time.time()
predict(ort_session, tokenizer, image, caption, box_threshold, text_threshold, device, is_benchmark=True)
infer_time = time.time() - t0
infer_times.append(infer_time)
print(f"实际推理 {i+1}/{test_runs} - 耗时: {infer_time*1000:.2f} ms")
# 3. 计算性能指标
total_test_time = time.time() - test_start
avg_infer_time = np.mean(infer_times)
std_infer_time = np.std(infer_times)
max_infer_time = np.max(infer_times)
min_infer_time = np.min(infer_times)
fps = test_runs / total_test_time
# 4. Print the report
print("\n" + "="*60)
print("📈 Benchmark report (timed phase only)")
print("="*60)
print(f"Runs: {test_runs}")
print(f"Total time: {total_test_time:.3f} s")
print(f"Average latency: {avg_infer_time*1000:.2f} ms (±{std_infer_time*1000:.2f} ms)")
print(f"Max latency: {max_infer_time*1000:.2f} ms")
print(f"Min latency: {min_infer_time*1000:.2f} ms")
print(f"Average FPS: {fps:.2f} frames/s")
print("="*60)
return {
"warmup_runs": warmup_runs,
"test_runs": test_runs,
"avg_infer_time_ms": avg_infer_time*1000,
"std_infer_time_ms": std_infer_time*1000,
"max_infer_time_ms": max_infer_time*1000,
"min_infer_time_ms": min_infer_time*1000,
"fps": fps
}
if __name__ == '__main__':
# Configuration
model_path = '../weights/ground_deform_fp16.onnx'
img_path = '../images/in/car_1.jpg'
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
DEVICE = "cpu"
WARMUP_RUNS = 5 # number of warm-up runs
TEST_RUNS = 10 # number of timed runs
# Load the image
image_source, image = load_image(img_path)
# Load the ONNX model (with graph optimizations enabled)
print("🔍 Loading ONNX model")
# so_options.enable_profiling = True # enable ORT profiling
ort_session = ort.InferenceSession(model_path,
sess_options=so_options,
providers=['ROCMExecutionProvider']
# provider_options=[{
# "device_id": 0,
# "migraphx_fp16_enable": "False",
# "migraphx_int8_enable": "False",
# # try disabling MIGraphX internal optimizations
# "migraphx_save_compiled_model": "False",
# }]
)
print("✅ 模型加载成功!自定义算子已就绪!")
# 查看当前执行引擎
current_provider = ort_session.get_providers()
print(f"✅ 模型加载完成 - 当前执行引擎: {current_provider}")
# 预加载tokenizer(只加载一次,核心优化)
print("\n📝 预加载BERT Tokenizer(仅加载一次)")
t0 = time.time()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(f"✅ Tokenizer加载完成 - 耗时: {(time.time() - t0):.3f} s")
# 第一步:运行完整的性能测试(预热+实际推理)
performance_result = benchmark_performance(
ort_session, tokenizer, image, TEXT_PROMPT,
BOX_TRESHOLD, TEXT_TRESHOLD,
WARMUP_RUNS, TEST_RUNS, DEVICE
)
# Step 2: run one full inference with verbose logging and save the annotated image
print("\n" + "="*60)
print("🎯 Final inference (verbose logging + saved result)")
print("="*60)
boxes, confs, phrases = predict(
ort_session, tokenizer, image, TEXT_PROMPT,
BOX_TRESHOLD, TEXT_TRESHOLD, DEVICE
)
# Draw and save the result image
ori_img = cv2.imread(img_path)
img_h = ori_img.shape[0]
img_w = ori_img.shape[1]
for i in range(len(boxes)):
one_box = boxes[i]
one_conf = confs[i]
one_cls = phrases[i]
x1 = int((one_box[0] - one_box[2] / 2) * img_w)
y1 = int((one_box[1] - one_box[3] / 2) * img_h)
x2 = int((one_box[0] + one_box[2] / 2) * img_w)
y2 = int((one_box[1] + one_box[3] / 2) * img_h)
cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(
ori_img, f'{one_cls} {one_conf:.2f}',
(x1-15, y1-15),
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
color=(255, 255, 255),
fontScale=1.5,
thickness=3
)
# Save the result
cv2.imwrite('./images/out/result.jpg', ori_img)
print(f"\n✅ Result saved to: ./images/out/result.jpg")
print(f"✅ Detected: {phrases} ({len(boxes)} objects)")
# profile_file = ort_session.end_profiling()
# print(f"\n📊 Profiling file written: {profile_file}")
from typing import Tuple, List, Dict
import cv2
import numpy as np
import torch
import onnxruntime as ort
import bisect
import time
import os
"""
针对模型前后处理和代码结构进行优化
1.预测结果获取优化prediction_logits = sigmoid(outputs[0][0])
2.输入数据提前获取直接传入,移除了对tokenizer的依赖
3.IO binding优化
"""
from groundingdino.util.inference import load_image
so_options = ort.SessionOptions()
custom_op_lib_path = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"
so_options.register_custom_ops_library(custom_op_lib_path)
# Enable ORT graph optimizations
so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def get_phrases_from_posmap(
posmap: np.ndarray, tokens: List[str], left_idx: int = 0, right_idx: int = 255
):
"""
【核心优化】直接用字符串列表映射,抛弃沉重的 Tokenizer
"""
assert isinstance(posmap, np.ndarray), "posmap must be np.ndarray"
if posmap.ndim == 1:
# Set elements outside the [left_idx, right_idx) window to False
posmap[:left_idx + 1] = False
posmap[right_idx:] = False
# Indices of the remaining non-zero (True) positions
non_zero_idx = np.nonzero(posmap)[0]
# Collect the activated words, filtering out special placeholder tokens
words = [tokens[i] for i in non_zero_idx if tokens[i] not in ["[CLS]", "[SEP]", "."]]
return " ".join(words).strip()
else:
raise NotImplementedError("posmap must be 1-dim")
def predict(
ort_session,
image: np.array,
text_cache: dict,
box_threshold: float,
text_threshold: float,
remove_combined: bool = False,
is_benchmark: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
input_dict = {
"img": np.expand_dims(np.asarray(image), axis=0),
"input_ids": text_cache['input_ids'],
"attention_mask": text_cache['attention_mask'],
"position_ids": text_cache['position_ids'],
"token_type_ids": text_cache['token_type_ids'],
"text_token_mask": text_cache['text_token_mask']
}
t0 = time.time()
outputs = ort_session.run(['logits', 'boxes'], input_dict)
infer_time = time.time() - t0
if not is_benchmark:
print(f"Inference time: {infer_time:.3f}s")
t0 = time.time()
prediction_logits = sigmoid(outputs[0][0])
prediction_boxes = outputs[1][0]
post_time = time.time() - t0
if not is_benchmark:
print(f"post time: {post_time:.3f}s")
if not is_benchmark:
print(f"\n=== Debug Info ===")
print(f"Prediction logits shape: {prediction_logits.shape}")
print(f"Prediction boxes shape: {prediction_boxes.shape}")
print(f"Max logit value: {np.max(prediction_logits):.4f}")
print(f"Mean logit value: {np.mean(prediction_logits):.4f}")
# Filter boxes by score threshold
max_values = np.max(prediction_logits, axis=1)
mask = max_values > box_threshold
logits = prediction_logits[mask]
boxes = prediction_boxes[mask]
# Phrase extraction
tokens = text_cache['tokens']
input_ids = text_cache['input_ids'][0].tolist()
# Handle special token boundaries
if remove_combined:
sep_idx = [i for i in range(len(input_ids)) if input_ids[i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()
insert_idx = bisect.bisect_left(sep_idx, max_idx)
right_idx = sep_idx[insert_idx]
left_idx = sep_idx[insert_idx - 1]
phrases.append(
get_phrases_from_posmap(logit > text_threshold, tokens, left_idx, right_idx)
)
else:
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokens)
for logit in logits
]
return boxes, np.max(logits, axis=1), phrases
def benchmark_performance(
ort_session, image, text_cache, box_threshold, text_threshold,
warmup_runs=5, test_runs=10
):
"""
性能测试函数:包含预热和实际推理
:param warmup_runs: 预热次数
:param test_runs: 实际测试次数
"""
print("="*60)
print("📊 开始性能测试(包含预热+实际推理)")
print("="*60)
print(f"\n🔥 预热阶段({warmup_runs} 次)- 不计入性能统计")
warmup_start = time.time()
for i in range(warmup_runs):
t0 = time.time()
predict(ort_session, image, text_cache, box_threshold, text_threshold, is_benchmark=True)
warmup_time = time.time() - t0
print(f"预热 {i+1}/{warmup_runs} - 耗时: {warmup_time*1000:.2f} ms")
total_warmup_time = time.time() - warmup_start
print(f"\n预热完成 - 总耗时: {total_warmup_time:.3f} s, 平均每次: {total_warmup_time/warmup_runs*1000:.2f} ms")
print(f"\n🚀 实际推理测试阶段({test_runs} 次)- 统计性能指标")
test_start = time.time()
infer_times = [] # 记录每次推理耗时
for i in range(test_runs):
t0 = time.time()
predict(ort_session, image, text_cache, box_threshold, text_threshold, is_benchmark=True)
infer_time = time.time() - t0
infer_times.append(infer_time)
print(f"实际推理 {i+1}/{test_runs} - 耗时: {infer_time*1000:.2f} ms")
# 计算性能指标
total_test_time = time.time() - test_start
avg_infer_time = np.mean(infer_times)
std_infer_time = np.std(infer_times)
max_infer_time = np.max(infer_times)
min_infer_time = np.min(infer_times)
fps = test_runs / total_test_time
# Print the report
print("\n" + "="*60)
print("📈 Benchmark report (timed phase only)")
print("="*60)
print(f"Runs: {test_runs}")
print(f"Total time: {total_test_time:.3f} s")
print(f"Average latency: {avg_infer_time*1000:.2f} ms (±{std_infer_time*1000:.2f} ms)")
print(f"Max latency: {max_infer_time*1000:.2f} ms")
print(f"Min latency: {min_infer_time*1000:.2f} ms")
print(f"Average FPS: {fps:.2f} frames/s")
print("="*60)
return {
"warmup_runs": warmup_runs,
"test_runs": test_runs,
"avg_infer_time_ms": avg_infer_time*1000,
"std_infer_time_ms": std_infer_time*1000,
"max_infer_time_ms": max_infer_time*1000,
"min_infer_time_ms": min_infer_time*1000,
"fps": fps
}
if __name__ == '__main__':
# Configuration
model_path = '../weights/ground_deform_fp16_all.onnx'
img_path = '../images/in/car_1.jpg'
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
DEVICE = "cpu"
WARMUP_RUNS = 5 # number of warm-up runs
TEST_RUNS = 10 # number of timed runs
image_source, image = load_image(img_path)
providers = [
'ROCMExecutionProvider',
'CPUExecutionProvider'
]
print("🔍 加载ONNX模型")
ort_session = ort.InferenceSession(model_path,
sess_options=so_options,
providers=providers
)
print(f"✅ 模型加载成功!自定义算子已就绪!当前执行引擎:{ort_session.get_providers()}")
# 提前通过get_caption_mask.py计算得到
TEXT_CACHE = {
'input_ids': np.array([[ 101, 2482, 1012, 102]], dtype=np.int64),
'attention_mask': np.array([[ True, True, True, True]], dtype=np.bool_),
'position_ids': np.array([[0, 0, 1, 0]], dtype=np.int64),
'token_type_ids': np.array([[0, 0, 0, 0]], dtype=np.int64),
'text_token_mask': np.array([[[ True, False, False, False],
[False, True, True, False],
[False, True, True, False],
[False, False, False, True]]], dtype=np.bool_),
# Token strings matching the ids above, used for fast decoding
'tokens': ["[CLS]", "car", ".", "[SEP]"]
}
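# Hedged sketch of how a cache like TEXT_CACHE could be regenerated for a different prompt,
# assuming the same BertTokenizer and GroundingDINO helper used in the tokenizer-based script
# above (get_caption_mask.py remains the authoritative source for the shipped values):
#   from transformers import BertTokenizer
#   from groundingdino.models.GroundingDINO.bertwarper import (
#       generate_masks_with_special_tokens_and_transfer_map,
#   )
#   tok = BertTokenizer.from_pretrained('bert-base-uncased')
#   tokenized = tok(["car ."], padding="longest", return_tensors="pt")
#   special = tok.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
#   masks, pos_ids, _ = generate_masks_with_special_tokens_and_transfer_map(tokenized, special, tok)
#   # input_ids / attention_mask / token_type_ids come from `tokenized`;
#   # 'tokens' comes from tok.convert_ids_to_tokens(tokenized["input_ids"][0])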
# Step 1: run the full benchmark (warm-up + timed runs)
performance_result = benchmark_performance(
ort_session, image, TEXT_CACHE,
BOX_TRESHOLD, TEXT_TRESHOLD,
WARMUP_RUNS, TEST_RUNS
)
# Step 2: run one full inference with verbose logging and save the annotated image
print("\n" + "="*60)
print("🎯 Final inference (verbose logging + saved result)")
print("="*60)
boxes, confs, phrases = predict(
ort_session, image, TEXT_CACHE,
BOX_TRESHOLD, TEXT_TRESHOLD
)
# Draw and save the result image
print("\n🎯 Drawing detections and saving the result image")
ori_img = cv2.imread(img_path)
img_h = ori_img.shape[0]
img_w = ori_img.shape[1]
for i in range(len(boxes)):
one_box = boxes[i]
one_conf = confs[i]
one_cls = phrases[i]
x1 = int((one_box[0] - one_box[2] / 2) * img_w)
y1 = int((one_box[1] - one_box[3] / 2) * img_h)
x2 = int((one_box[0] + one_box[2] / 2) * img_w)
y2 = int((one_box[1] + one_box[3] / 2) * img_h)
cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(
ori_img, f'{one_cls} {one_conf:.2f}',
(x1-15, y1-15),
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
color=(255, 255, 255),
fontScale=1.5,
thickness=3
)
# Save the result
cv2.imwrite('./result.jpg', ori_img)
print(f"\n✅ Result saved to: ./result.jpg")
print(f"✅ Detected: {phrases} ({len(boxes)} objects)")
from typing import Tuple, List, Dict
import cv2
import numpy as np
import torch
import onnxruntime as ort
import bisect
import time
import os
"""
针对模型前后处理和代码结构进行优化
1.预测结果获取优化prediction_logits = sigmoid(outputs[0][0])
2.输入数据提前获取直接传入,移除了对tokenizer的依赖
3.IO binding优化
"""
from groundingdino.util.inference import load_image
so_options = ort.SessionOptions()
custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
so_options.register_custom_ops_library(custom_op_lib_path)
# Enable ORT graph optimizations
so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def get_phrases_from_posmap(
posmap: np.ndarray, tokens: List[str], left_idx: int = 0, right_idx: int = 255
):
"""
【核心优化】直接用字符串列表映射,抛弃沉重的 Tokenizer
"""
assert isinstance(posmap, np.ndarray), "posmap must be np.ndarray"
if posmap.ndim == 1:
# Set elements outside the [left_idx, right_idx) window to False
posmap[:left_idx + 1] = False
posmap[right_idx:] = False
# Indices of the remaining non-zero (True) positions
non_zero_idx = np.nonzero(posmap)[0]
# Collect the activated words, filtering out special placeholder tokens
words = [tokens[i] for i in non_zero_idx if tokens[i] not in ["[CLS]", "[SEP]", "."]]
return " ".join(words).strip()
else:
raise NotImplementedError("posmap must be 1-dim")
def predict(
ort_session,
io_binding,
image: np.array,
text_cache: dict,
box_threshold: float,
text_threshold: float,
remove_combined: bool = False,
is_benchmark: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
# input_dict = {
# "img": np.expand_dims(np.asarray(image), axis=0),
# "input_ids": text_cache['input_ids'],
# "attention_mask": text_cache['attention_mask'],
# "position_ids": text_cache['position_ids'],
# "token_type_ids": text_cache['token_type_ids'],
# "text_token_mask": text_cache['text_token_mask']
# }
t0 = time.time()
# 1. Bind only the image, which changes every frame; the static text inputs were bound once up front
img_tensor = np.expand_dims(np.asarray(image), axis=0)
io_binding.bind_cpu_input('img', img_tensor)
# 2. Bind the outputs we need
io_binding.bind_output('logits')
io_binding.bind_output('boxes')
# 3. Run inference through IO binding
ort_session.run_with_iobinding(io_binding)
ort_outputs = io_binding.copy_outputs_to_cpu()
# Clear the output bindings, otherwise they accumulate across iterations and cause memory growth/errors
io_binding.clear_binding_outputs()
infer_time = time.time() - t0
if not is_benchmark:
print(f"Inference time: {infer_time:.3f}s")
prediction_logits = sigmoid(ort_outputs[0][0])
prediction_boxes = ort_outputs[1][0]
# outputs = ort_session.run(['logits', 'boxes'], input_dict)
# infer_time = time.time() - t0
# if not is_benchmark:
# print(f"Inference time: {infer_time:.3f}s")
# t0 = time.time()
# prediction_logits = sigmoid(outputs[0][0])
# prediction_boxes = outputs[1][0]
# post_time = time.time() - t0
# if not is_benchmark:
# print(f"post time: {post_time:.3f}s")
if not is_benchmark:
print(f"\n=== Debug Info ===")
print(f"Prediction logits shape: {prediction_logits.shape}")
print(f"Prediction boxes shape: {prediction_boxes.shape}")
print(f"Max logit value: {np.max(prediction_logits):.4f}")
print(f"Mean logit value: {np.mean(prediction_logits):.4f}")
# Filter boxes by score threshold
max_values = np.max(prediction_logits, axis=1)
mask = max_values > box_threshold
logits = prediction_logits[mask]
boxes = prediction_boxes[mask]
# Phrase extraction
tokens = text_cache['tokens']
input_ids = text_cache['input_ids'][0].tolist()
# Handle special token boundaries
if remove_combined:
sep_idx = [i for i in range(len(input_ids)) if input_ids[i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()
insert_idx = bisect.bisect_left(sep_idx, max_idx)
right_idx = sep_idx[insert_idx]
left_idx = sep_idx[insert_idx - 1]
phrases.append(
get_phrases_from_posmap(logit > text_threshold, tokens, left_idx, right_idx)
)
else:
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokens)
for logit in logits
]
return boxes, np.max(logits, axis=1), phrases
def benchmark_performance(
ort_session, io_binding, image, text_cache, box_threshold, text_threshold,
warmup_runs=5, test_runs=10
):
"""
性能测试函数:包含预热和实际推理
:param warmup_runs: 预热次数
:param test_runs: 实际测试次数
"""
print("="*60)
print("📊 开始性能测试(包含预热+实际推理)")
print("="*60)
print(f"\n🔥 预热阶段({warmup_runs} 次)- 不计入性能统计")
warmup_start = time.time()
for i in range(warmup_runs):
t0 = time.time()
predict(ort_session, io_binding, image, text_cache, box_threshold, text_threshold, is_benchmark=True)
warmup_time = time.time() - t0
print(f"预热 {i+1}/{warmup_runs} - 耗时: {warmup_time*1000:.2f} ms")
total_warmup_time = time.time() - warmup_start
print(f"\n预热完成 - 总耗时: {total_warmup_time:.3f} s, 平均每次: {total_warmup_time/warmup_runs*1000:.2f} ms")
print(f"\n🚀 实际推理测试阶段({test_runs} 次)- 统计性能指标")
test_start = time.time()
infer_times = [] # 记录每次推理耗时
for i in range(test_runs):
t0 = time.time()
predict(ort_session, io_binding, image, text_cache, box_threshold, text_threshold, is_benchmark=True)
infer_time = time.time() - t0
infer_times.append(infer_time)
print(f"实际推理 {i+1}/{test_runs} - 耗时: {infer_time*1000:.2f} ms")
# 计算性能指标
total_test_time = time.time() - test_start
avg_infer_time = np.mean(infer_times)
std_infer_time = np.std(infer_times)
max_infer_time = np.max(infer_times)
min_infer_time = np.min(infer_times)
fps = test_runs / total_test_time
# Print the report
print("\n" + "="*60)
print("📈 Benchmark report (timed phase only)")
print("="*60)
print(f"Runs: {test_runs}")
print(f"Total time: {total_test_time:.3f} s")
print(f"Average latency: {avg_infer_time*1000:.2f} ms (±{std_infer_time*1000:.2f} ms)")
print(f"Max latency: {max_infer_time*1000:.2f} ms")
print(f"Min latency: {min_infer_time*1000:.2f} ms")
print(f"Average FPS: {fps:.2f} frames/s")
print("="*60)
return {
"warmup_runs": warmup_runs,
"test_runs": test_runs,
"avg_infer_time_ms": avg_infer_time*1000,
"std_infer_time_ms": std_infer_time*1000,
"max_infer_time_ms": max_infer_time*1000,
"min_infer_time_ms": min_infer_time*1000,
"fps": fps
}
if __name__ == '__main__':
# Configuration
model_path = '../weights/ground_deform_fp16.onnx'
img_path = '../images/in/car_1.jpg'
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
DEVICE = "cpu"
WARMUP_RUNS = 5 # number of warm-up runs
TEST_RUNS = 10 # number of timed runs
image_source, image = load_image(img_path)
print("🔍 加载ONNX模型")
ort_session = ort.InferenceSession(model_path,
sess_options=so_options,
providers=['ROCMExecutionProvider','CPUExecutionProvider']
)
print(f"✅ 模型加载成功!自定义算子已就绪!当前执行引擎:{ort_session.get_providers()}")
# 提前通过get_caption_mask.py计算得到
TEXT_CACHE = {
'input_ids': np.array([[ 101, 2482, 1012, 102]], dtype=np.int64),
'attention_mask': np.array([[ True, True, True, True]], dtype=np.bool_),
'position_ids': np.array([[0, 0, 1, 0]], dtype=np.int64),
'token_type_ids': np.array([[0, 0, 0, 0]], dtype=np.int64),
'text_token_mask': np.array([[[ True, False, False, False],
[False, True, True, False],
[False, True, True, False],
[False, False, False, True]]], dtype=np.bool_),
# Token strings matching the ids above, used for fast decoding
'tokens': ["[CLS]", "car", ".", "[SEP]"]
}
io_binding = ort_session.io_binding()
static_keys = ['input_ids', 'attention_mask', 'position_ids', 'token_type_ids', 'text_token_mask']
for key in static_keys:
io_binding.bind_cpu_input(key, TEXT_CACHE[key])
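# Hedged note: the outputs could also be pre-bound to pre-allocated OrtValue buffers via
# io_binding.bind_ortvalue_output(), avoiding the per-frame bind/copy inside predict(); the
# shape and dtype below are illustrative assumptions, not values taken from this model.
#   out_logits = ort.OrtValue.ortvalue_from_shape_and_type([1, 900, 256], np.float32)
#   io_binding.bind_ortvalue_output('logits', out_logits)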
# Step 1: run the full benchmark (warm-up + timed runs)
performance_result = benchmark_performance(
ort_session, io_binding, image, TEXT_CACHE,
BOX_TRESHOLD, TEXT_TRESHOLD,
WARMUP_RUNS, TEST_RUNS
)
# Step 2: run one full inference with verbose logging and save the annotated image
print("\n" + "="*60)
print("🎯 Final inference (verbose logging + saved result)")
print("="*60)
boxes, confs, phrases = predict(
ort_session, io_binding, image, TEXT_CACHE,
BOX_TRESHOLD, TEXT_TRESHOLD
)
# Draw and save the result image
print("\n🎯 Drawing detections and saving the result image")
ori_img = cv2.imread(img_path)
img_h = ori_img.shape[0]
img_w = ori_img.shape[1]
for i in range(len(boxes)):
one_box = boxes[i]
one_conf = confs[i]
one_cls = phrases[i]
x1 = int((one_box[0] - one_box[2] / 2) * img_w)
y1 = int((one_box[1] - one_box[3] / 2) * img_h)
x2 = int((one_box[0] + one_box[2] / 2) * img_w)
y2 = int((one_box[1] + one_box[3] / 2) * img_h)
cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(
ori_img, f'{one_cls} {one_conf:.2f}',
(x1-15, y1-15),
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
color=(255, 255, 255),
fontScale=1.5,
thickness=3
)
# Save the result
cv2.imwrite('./images/out/result.jpg', ori_img)
print(f"\n✅ Result saved to: ./images/out/result.jpg")
print(f"✅ Detected: {phrases} ({len(boxes)} objects)")
import onnx
from onnxsim import simplify
from onnxconverter_common import float16
onnx_model_path = "../weights/ground_deform.onnx"
sim_model_path = "../weights/ground_deform_sim.onnx"
fp16_model_path = "../weights/ground_deform_fp16_all.onnx"
custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so"
# ==========================================
# Step 1: ONNX Simplify (with the custom op library)
# ==========================================
# print("1️⃣ Running ONNX Simplify...")
# model = onnx.load(onnx_model_path)
# model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
# if check:
# onnx.save(model_simp, sim_model_path)
# print(f"✅ Simplify done! Saved to {sim_model_path}")
# else:
# print("❌ Simplify check failed!")
# exit()
# ==========================================
# Step 2: FP16 mixed-precision conversion (keeping the custom op out of it)
# ==========================================
print("\n2️⃣ Running FP16 mixed-precision conversion...")
# Reload the simplified model
model_to_fp16 = onnx.load(sim_model_path)
original_cast_nodes = [node.name for node in model_to_fp16.graph.node if node.op_type == "Cast"]
print(f"🔍 Found {len(original_cast_nodes)} native Cast nodes; all added to the protection list.")
model_fp16 = float16.convert_float_to_float16(
model_to_fp16,
# op_block_list=["ms_deform_attn"], # block the custom attention op (when using the FP32 custom op)
node_block_list=original_cast_nodes, # keep all native Cast nodes untouched
keep_io_types=True # keep the model's top-level inputs/outputs in FP32
)
onnx.save(model_fp16, fp16_model_path)
print(f"✅ FP16 conversion done! Saved to {fp16_model_path}")
import typer
from groundingdino.util.inference import load_model, load_image, predict
from tqdm import tqdm
import torchvision
import torch
import fiftyone as fo
def main(
image_directory: str = 'test_grounding_dino',
text_prompt: str = 'bus, car',
box_threshold: float = 0.15,
text_threshold: float = 0.10,
export_dataset: bool = False,
view_dataset: bool = False,
export_annotated_images: bool = True,
weights_path : str = "groundingdino_swint_ogc.pth",
config_path: str = "../../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
subsample: int = None,
):
model = load_model(config_path, weights_path)
dataset = fo.Dataset.from_images_dir(image_directory)
samples = []
if subsample is not None:
if subsample < len(dataset):
dataset = dataset.take(subsample).clone()
for sample in tqdm(dataset):
image_source, image = load_image(sample.filepath)
boxes, logits, phrases = predict(
model=model,
image=image,
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold,
)
detections = []
for box, logit, phrase in zip(boxes, logits, phrases):
rel_box = torchvision.ops.box_convert(box, 'cxcywh', 'xywh')
detections.append(
fo.Detection(
label=phrase,
bounding_box=rel_box,
confidence=logit,
))
# Store detections in a field name of your choice
sample["detections"] = fo.Detections(detections=detections)
sample.save()
# launches the Voxel51 FiftyOne UI for viewing the dataset.
if view_dataset:
session = fo.launch_app(dataset)
session.wait()
# exports COCO dataset ready for training
if export_dataset:
dataset.export(
'coco_dataset',
dataset_type=fo.types.COCODetectionDataset,
)
# saves bounding boxes plotted on the input images to disk
if export_annotated_images:
dataset.draw_labels(
'images_with_bounding_boxes',
label_fields=['detections']
)
if __name__ == '__main__':
typer.run(main)
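# Hedged usage example (the script filename is assumed; Typer exposes the parameters above as
# dashed CLI options):
#   python grounding_dino_to_fiftyone.py --image-directory test_grounding_dino \
#       --text-prompt "bus, car" --box-threshold 0.15 --text-threshold 0.10 --export-dataset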
import argparse
import os
from dataclasses import dataclass
from typing import Dict, Tuple
import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util import get_tokenlizer
from groundingdino.models.GroundingDINO.bertwarper import (
generate_masks_with_special_tokens_and_transfer_map,
)
@dataclass
class TextInputs:
input_ids: torch.Tensor
token_type_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
text_self_attention_masks: torch.Tensor
class GroundingDINOOnnxWrapper(torch.nn.Module):
"""
ONNX 导出用 wrapper:
- 让 ONNX 输入完全是 Tensor(image + tokenized)
- 复用 GroundingDINO 原始 backbone/transformer/head
- 跳过 Python tokenizer/文本mask生成(这些在导出前就做好,作为输入喂进来)
"""
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
image: torch.Tensor, # [B,3,H,W] float32
input_ids: torch.Tensor, # [B,S] int64
token_type_ids: torch.Tensor, # [B,S] int64
attention_mask: torch.Tensor, # [B,S] int64 (used as text_token_mask)
position_ids: torch.Tensor, # [B,S] int64
text_self_attention_masks: torch.Tensor, # [B,S,S] bool/int64 (used when sub_sentence_present)
) -> Tuple[torch.Tensor, torch.Tensor]:
# ---- Text encoding (equivalent to the logic inside GroundingDINO.forward) ----
if self.model.sub_sentence_present:
tokenized_for_encoder: Dict[str, torch.Tensor] = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": text_self_attention_masks,
"position_ids": position_ids,
}
else:
tokenized_for_encoder = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
}
bert_output = self.model.bert(**tokenized_for_encoder)
encoded_text = self.model.feat_map(bert_output["last_hidden_state"])
text_token_mask = attention_mask.to(torch.bool)
text_dict = {
"encoded_text": encoded_text,
"text_token_mask": text_token_mask,
"position_ids": position_ids,
"text_self_attention_masks": text_self_attention_masks.to(torch.bool),
}
# ---- Visual encoding + transformer (largely copied from GroundingDINO.forward) ----
if isinstance(image, (list, torch.Tensor)):
from groundingdino.util.misc import nested_tensor_from_tensor_list
samples = nested_tensor_from_tensor_list(image)
else:
samples = image
self.model.set_image_tensor(samples)
import torch.nn.functional as F
from groundingdino.util.misc import NestedTensor
srcs = []
masks = []
for l, feat in enumerate(self.model.features):
src, mask = feat.decompose()
srcs.append(self.model.input_proj[l](src))
masks.append(mask)
if self.model.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.model.num_feature_levels):
if l == _len_srcs:
src = self.model.input_proj[l](self.model.features[-1].tensors)
else:
src = self.model.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.model.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
self.model.poss.append(pos_l)
hs, reference, _, _, _ = self.model.transformer(
srcs, masks, None, self.model.poss, None, None, text_dict
)
from groundingdino.util.misc import inverse_sigmoid
outputs_coord_list = []
for layer_ref_sig, layer_bbox_embed, layer_hs in zip(
reference[:-1], self.model.bbox_embed, hs
):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
outputs_class = torch.stack(
[
layer_cls_embed(layer_hs, text_dict)
for layer_cls_embed, layer_hs in zip(self.model.class_embed, hs)
]
)
pred_logits = outputs_class[-1]
pred_boxes = outputs_coord_list[-1]
self.model.unset_image_tensor()
return pred_logits, pred_boxes
def load_torch_model(model_config_path: str, model_checkpoint_path: str, device: str):
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
model.eval()
model.to(device)
# ONNX tracing is incompatible with gradient checkpointing, so disable it everywhere for export
for m in model.modules():
if hasattr(m, "use_checkpoint"):
try:
setattr(m, "use_checkpoint", False)
except Exception:
pass
return model, args
def build_text_inputs(
tokenizer,
caption: str,
device: str,
max_text_len: int,
special_token_ids,
) -> TextInputs:
caption = caption.lower().strip()
if not caption.endswith("."):
caption = caption + "."
tokenized = tokenizer([caption], padding="longest", return_tensors="pt")
tokenized = {k: v.to(device) for k, v in tokenized.items()}
text_self_attention_masks, position_ids, _ = generate_masks_with_special_tokens_and_transfer_map(
tokenized, special_token_ids, tokenizer
)
# Truncate to max_text_len (matching the model's forward behavior)
if text_self_attention_masks.shape[1] > max_text_len:
s = max_text_len
text_self_attention_masks = text_self_attention_masks[:, :s, :s]
position_ids = position_ids[:, :s]
tokenized["input_ids"] = tokenized["input_ids"][:, :s]
tokenized["attention_mask"] = tokenized["attention_mask"][:, :s]
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, :s]
return TextInputs(
input_ids=tokenized["input_ids"].to(torch.int64),
token_type_ids=tokenized["token_type_ids"].to(torch.int64),
attention_mask=tokenized["attention_mask"].to(torch.int64),
position_ids=position_ids.to(torch.int64),
text_self_attention_masks=text_self_attention_masks,
)
def main():
parser = argparse.ArgumentParser("Export GroundingDINO to ONNX", add_help=True)
parser.add_argument("--config_file", "-c", type=str, required=True)
parser.add_argument("--checkpoint_path", "-p", type=str, required=True)
parser.add_argument("--output_onnx", "-o", type=str, required=True, help="输出 onnx 路径")
parser.add_argument("--text_prompt", "-t", type=str, required=True, help="用于构建 dummy 文本输入(影响 seq_len)")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--cpu-only", action="store_true")
parser.add_argument("--dynamic", action="store_true", help="启用动态 H/W 与 seq_len 轴(更通用但可能更慢/更难优化)")
parser.add_argument("--simplify", action="store_true", help="尝试用 onnxsim 简化(需要安装 onnxsim)")
parser.add_argument("--image_hw", type=str, default="800,1333", help="dummy image H,W(默认与 transform 常见输出一致)")
args = parser.parse_args()
device = "cpu" if args.cpu_only else ("cuda" if torch.cuda.is_available() else "cpu")
model, cfg = load_torch_model(args.config_file, args.checkpoint_path, device=device)
tokenizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
special_token_ids = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
text_inputs = build_text_inputs(
tokenizer=tokenizer,
caption=args.text_prompt,
device=device,
max_text_len=getattr(cfg, "max_text_len", 256),
special_token_ids=special_token_ids,
)
h_str, w_str = args.image_hw.split(",")
H, W = int(h_str), int(w_str)
dummy_image = torch.randn(1, 3, H, W, device=device, dtype=torch.float32)
wrapper = GroundingDINOOnnxWrapper(model)
wrapper.eval()
os.makedirs(os.path.dirname(os.path.abspath(args.output_onnx)) or ".", exist_ok=True)
input_names = [
"image",
"input_ids",
"token_type_ids",
"attention_mask",
"position_ids",
"text_self_attention_masks",
]
output_names = ["pred_logits", "pred_boxes"]
dynamic_axes = None
if args.dynamic:
dynamic_axes = {
"image": {0: "batch", 2: "height", 3: "width"},
"input_ids": {0: "batch", 1: "seq"},
"token_type_ids": {0: "batch", 1: "seq"},
"attention_mask": {0: "batch", 1: "seq"},
"position_ids": {0: "batch", 1: "seq"},
"text_self_attention_masks": {0: "batch", 1: "seq", 2: "seq"},
"pred_logits": {0: "batch"},
"pred_boxes": {0: "batch"},
}
with torch.no_grad():
torch.onnx.export(
wrapper,
(
dummy_image,
text_inputs.input_ids,
text_inputs.token_type_ids,
text_inputs.attention_mask,
text_inputs.position_ids,
text_inputs.text_self_attention_masks,
),
args.output_onnx,
opset_version=args.opset,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
if args.simplify:
try:
import onnx
from onnxsim import simplify
onnx_model = onnx.load(args.output_onnx)
simplified_model, ok = simplify(onnx_model)
if ok:
onnx.save(simplified_model, args.output_onnx)
print(f"✅ onnxsim 简化完成: {args.output_onnx}")
else:
print("⚠️ onnxsim 简化失败(模型未修改)")
except Exception as e:
print(f"⚠️ onnxsim 简化跳过: {e}")
print(f"✅ 导出完成: {args.output_onnx}")
if __name__ == "__main__":
main()
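# Hedged usage example (the script filename is assumed; config/weight paths mirror the ones used
# elsewhere in this repo):
#   python export_groundingdino_onnx.py \
#       -c groundingdino/config/GroundingDINO_SwinT_OGC.py \
#       -p weights/groundingdino_swint_ogc.pth \
#       -o ../weights/ground_deform.onnx \
#       -t "car ." --opset 17 --simplify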
import argparse
from functools import partial
import cv2
import requests
import os
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import warnings
import torch
# prepare the environment
os.system("python setup.py build develop --user")
os.system("pip install packaging==21.3")
os.system("pip install gradio==3.50.2")
warnings.filterwarnings("ignore")
import gradio as gr
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download
# Use this command to evaluate the Grounding DINO model
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swint_ogc.pth"
def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
args = SLConfig.fromfile(model_config_path)
model = build_model(args)
args.device = device
cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
checkpoint = torch.load(cache_file, map_location='cpu')
log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
print("Model loaded from {} \n => {}".format(cache_file, log))
_ = model.eval()
return model
def image_transform_grounding(init_image):
transform = T.Compose([
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image, _ = transform(init_image, None) # 3, h, w
return init_image, image
def image_transform_grounding_for_vis(init_image):
transform = T.Compose([
T.RandomResize([800], max_size=1333),
])
image, _ = transform(init_image, None) # 3, h, w
return image
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename)
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
init_image = input_image.convert("RGB")
original_size = init_image.size
_, image_tensor = image_transform_grounding(init_image)
image_pil: Image = image_transform_grounding_for_vis(init_image)
# run grounding
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
return image_with_box
if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
args = parser.parse_args()
block = gr.Blocks().queue()
with block:
gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
gr.Markdown("### Open-World Detection with Grounding DINO")
with gr.Row():
with gr.Column():
input_image = gr.Image(source='upload', type="pil")
grounding_caption = gr.Textbox(label="Detection Prompt")
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
box_threshold = gr.Slider(
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
text_threshold = gr.Slider(
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
with gr.Column():
gallery = gr.outputs.Image(
type="pil",
# label="grounding results"
).style(full_width=True, full_height=True)
# gallery = gr.Gallery(label="Generated images", show_label=False).style(
# grid=[1], height="auto", container=True, full_width=True, full_height=True)
run_button.click(fn=run_grounding, inputs=[
input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)