"csrc/git@developer.sourcefind.cn:jerrrrry/infinilm.git" did not exist on "e76bb324ac048e7c91f165610cbec151791dbb62"
Commit 0816cc7e authored by mashun1's avatar mashun1
Browse files

bagel

parents
Pipeline #2730 canceled with stages
wandb
__pycache__
.vscode
notebooks
results
*.ipynb_checkpoints
eval_results
tests
.DS_Store
gradio.sh
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
# VLM
We follow [InternVL2](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html) to evaluate the performance on MME, MMBench, MMMU, MMVet, MathVista and MMVP.
## Data preparation
Please follow the [InternVL2](https://internvl.readthedocs.io/en/latest/get_started/eval_data_preparation.html) guide to prepare the corresponding data, then link the data under `vlm`.
The final directory structure is:
```shell
data
├── MathVista
├── mmbench
├── mme
├── MMMU
├── mm-vet
└── MMVP
```
## Evaluation
Directly run `scripts/eval/run_eval_vlm.sh` to evaluate the different benchmarks. The output will be saved in `$output_path`; a minimal invocation sketch follows the list below.
- Set `$model_path` and `$output_path` to the checkpoint path and the output/log directory.
- Increase `GPUS` if you want to run faster.
- For MMBench, please use the official [evaluation server](https://mmbench.opencompass.org.cn/mmbench-submission).
- For MMVet, please use the official [evaluation server](https://huggingface.co/spaces/whyu/MM-Vet_Evaluator).
- For MathVista, please set `$openai_api_key` in `scripts/eval/run_eval_vlm.sh` and `your_api_url` in `eval/vlm/eval/mathvista/utilities.py`. The default GPT version is `gpt-4o-2024-11-20`.
- For MMMU, we use CoT in the report, which improves accuracy by about 2%. For evaluating open-ended answers, we use GPT-4o as the judge.
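A minimal invocation sketch (the exact variable names and paths below are assumptions to adapt; the script may expect them edited in place rather than exported):
```shell
export model_path=/path/to/BAGEL-7B-MoT   # checkpoint to evaluate
export output_path=./eval_results/vlm     # where logs and predictions are written
export GPUS=8                             # increase to run faster
export openai_api_key=sk-...              # needed for MathVista (and MMMU open-ended judging)
bash scripts/eval/run_eval_vlm.sh
```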
# GenEval
We modify the code in [GenEval](https://github.com/djghosh13/geneval/tree/main) for faster evaluation.
## Setup
Install the following dependencies:
```shell
pip install open-clip-torch
pip install clip-benchmark
pip install --upgrade setuptools
sudo pip install -U openmim
sudo mim install mmengine mmcv-full==1.7.2
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection; git checkout 2.x
pip install -v -e .
```
Download Detector:
```shell
cd ./eval/gen/geneval
mkdir model
bash ./evaluation/download_models.sh ./model
```
## Evaluation
Directly run `scripts/eval/run_geneval.sh` to evaluate GenEval. The output will be saved in `$output_path`.
- Set `$model_path` and `$output_path` to the checkpoint path and the output/log directory.
- Set `metadata_file` to `./eval/gen/geneval/prompts/evaluation_metadata.jsonl` for original GenEval prompts.
# WISE
We modify the code in [WISE](https://github.com/PKU-YuanGroup/WISE/tree/main) for faster evaluation.
## Evaluation
Directly run `scripts/eval/run_wise.sh` to evaluate WISE. The output will be saved in `$output_path`.
- Set `$model_path` and `$output_path` to the checkpoint path and the output/log directory.
- Set `$openai_api_key` in `scripts/eval/run_wise.sh` and `your_api_url` in `eval/gen/wise/gpt_eval_mp.py`. The default GPT version is `gpt-4o-2024-11-20`.
- Use `think` for thinking mode.
# GEdit-Bench
Please follow [GEdit-Bench](https://github.com/stepfun-ai/Step1X-Edit/blob/main/GEdit-Bench/EVAL.md) for evaluation.
# IntelligentBench
TBD
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Bagel
## Paper
`Emerging Properties in Unified Multimodal Pretraining`
* https://arxiv.org/abs/2505.14683
## Model Architecture
BAGEL adopts an MoT (Mixture of Transformers) architecture composed of two Transformer experts: one dedicated to multimodal understanding and the other to multimodal generation. Accordingly, the model uses two separate vision encoders, one for understanding and one for generation.
![alt text](readme_imgs/arch.png)
## Algorithm
At every layer, the two Transformer experts process the same token sequence through a shared self-attention mechanism. For text tokens, BAGEL follows the Next-Token-Prediction paradigm, inheriting the proven strengths of autoregressive language models; for visual tokens, it adopts the Rectified Flow method (a minimal sketch follows the figure below).
![alt text](readme_imgs/alg.png)
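The block below is a minimal, self-contained sketch of the idea above, not BAGEL's actual implementation: a single shared self-attention processes the whole token sequence, and each token's feed-forward pass is then routed to either the understanding expert or the generation expert. Module names, sizes, and the routing mask are illustrative assumptions; attention masking and the Rectified Flow training objective are omitted.
```python
import torch
import torch.nn as nn

class MoTDecoderLayerSketch(nn.Module):
    """Toy MoT layer: shared self-attention, per-modality FFN experts."""
    def __init__(self, dim=256, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Two experts: one FFN for understanding (text) tokens, one for generation (visual) tokens.
        self.ffn_und = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.ffn_gen = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, is_gen_token):
        # x: (batch, seq, dim); is_gen_token: (batch, seq) bool mask marking generation tokens.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)  # shared self-attention over the sequence
        x = x + attn_out
        h = self.norm2(x)
        out = torch.where(is_gen_token.unsqueeze(-1), self.ffn_gen(h), self.ffn_und(h))
        return x + out

if __name__ == "__main__":
    layer = MoTDecoderLayerSketch()
    tokens = torch.randn(2, 16, 256)
    gen_mask = torch.zeros(2, 16, dtype=torch.bool)
    gen_mask[:, 8:] = True  # pretend the last 8 tokens are visual (generation) tokens
    print(layer(tokens, gen_mask).shape)  # torch.Size([2, 16, 256])
```
BAGEL's real MoT layers differ in detail (each expert has its own attention projections, and visual tokens are trained with a Rectified Flow objective rather than this toy forward pass).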
## Environment Setup
Note: the packages provided with this project are only compatible with this project.
### Docker (Option 1)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
docker run --shm-size 100g --network=host --name=bagel --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
pip install whl/torch-2.5.1+das.opt1.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
pip install whl/torchvision-0.20.1a0-04d8fc4-cp310-cp310-manylinux_2_28_x86_64.whl
### Dockerfile (Option 2)
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 100g --network=host --name=bagel --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
pip install whl/torch-2.5.1+das.opt1.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
pip install whl/torchvision-0.20.1a0-04d8fc4-cp310-cp310-manylinux_2_28_x86_64.whl
### Anaconda (Option 3)
1. The DCU-specific deep-learning libraries required by this project can be downloaded and installed from the 光合 developer community: https://developer.hpccube.com/tool/
```
DTK driver: dtk25.04
python:python3.10
triton:3.0
flash-attn:2.6.1
deepspeed:0.14.2
apex:1.4.0
```
2. Other, non-DCU-specific libraries can be installed directly from requirements.txt:
```
pip install -r requirements.txt
pip install whl/torch-2.5.1+das.opt1.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
pip install whl/torchvision-0.20.1a0-04d8fc4-cp310-cp310-manylinux_2_28_x86_64.whl
```
## Datasets
## Training
## Inference
```bash
python inference_test.py
```
Note: update the relevant parameters, e.g. `model_path`, before running.
## Results
|Task|Input 1|Input 2|Output|
|:---:|:---|:---:|:---|
|t2i|A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere||![](readme_imgs/r1.png)|
|editing|She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.|![](test_images/women.jpg)|![](readme_imgs/r2.png)|
||Can someone explain what’s funny about this meme??|![](test_images/meme.jpg)|The humor in this meme comes from the exaggerated change in handwriting over the course of an exam. At the beginning, the handwriting is clear and legible, indicating a calm and focused start to the exam. However, as the exam progresses, the handwriting becomes increasingly difficult to read, suggesting that the student becomes more anxious and less able to write clearly. The final part of the meme shows an electrocardiogram (ECG) reading, which humorously implies that the student's heart rate is racing and their writing is becoming erratic, much like the fluctuations seen in an ECG. This visual pun highlights the common experience of feeling more nervous and less in control as an exam progresses.|
### Accuracy
## Application Scenarios
### Algorithm Category
`Multimodal`
### Key Application Industries
`E-commerce, education, broadcast media`
## Pretrained Weights
[BAGEL-7B-MoT](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)
## Source Repository & Issue Feedback
* https://developer.sourcefind.cn/codes/modelzoo/bagel_pytorch
## References
* https://github.com/ByteDance-Seed/Bagel/tree/main
<p align="center">
<img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="480"/>
</p>
<p align="center">
<a href="https://bagel-ai.org/">
<img
src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
alt="BAGEL Website"
/>
</a>
<a href="https://arxiv.org/abs/2505.14683">
<img
src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
alt="BAGEL Paper on arXiv"
/>
</a>
<a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
<img
src="https://img.shields.io/badge/BAGEL-Model-yellow?logo=huggingface&logoColor=yellow"
alt="BAGEL Model"
/>
</a>
<a href="https://demo.bagel-ai.org/">
<img
src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
alt="BAGEL Demo"
/>
</a>
<a href="https://huggingface.co/spaces/ByteDance-Seed/BAGEL">
<img
src="https://img.shields.io/badge/BAGEL-Space-orange?logo=huggingface&logoColor=yellow"
alt="BAGEL Model"
/>
</a>
<a href="https://discord.gg/Z836xxzy">
<img
src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
alt="BAGEL Discord"
/>
</a>
<a href="mailto:bagel@bytedance.com">
<img
src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
alt="BAGEL Email"
/>
</a>
</p>
# Unified Model for Multimodal Understanding and Generation
> [Chaorui Deng*](https://scholar.google.com/citations?hl=en&user=k0TWfBoAAAAJ), [Deyao Zhu*](https://tsutikgiau.github.io/), [Kunchang Li*](https://andy1621.github.io/), [Chenhui Gou*](https://www.linkedin.com/in/chenhui-gou-9201081a1/?originalSubdomain=au), [Feng Li*](https://fengli-ust.github.io/), [Zeyu Wang](https://zw615.github.io/), Shu Zhong, [Weihao Yu](https://whyu.me/), [Xiaonan Nie](https://codecaution.github.io/), [Ziang Song](https://www.linkedin.com/in/ziang-song-43b0ab8a/), Guang Shi :email: , [Haoqi Fan* :tophat: ](https://haoqifan.github.io/)
>
> contact: shiguang.sg@bytedance.com
>
> We present **BAGEL**, an open‑source multimodal foundation model with 7B active parameters (14B total) trained on large‑scale interleaved multimodal data. BAGEL outperforms the current top‑tier open‑source VLMs like Qwen2.5-VL and InternVL-2.5 on standard multimodal understanding leaderboards, and delivers text‑to‑image quality that is competitive with strong specialist generators such as SD3.
Moreover, BAGEL demonstrates superior qualitative results in classical image‑editing scenarios compared with the leading open-source models. More importantly, it extends to free-form visual manipulation, multiview synthesis, and world navigation, capabilities that constitute "world-modeling" tasks beyond the scope of previous image-editing models.
The figure below showcases BAGEL's qualitative performance.
<p align="center"><img src="assets/teaser.webp" width="95%"></p>
<!-- ## 🧠 Method
BAGEL adopts a Mixture-of-Transformer-Experts (MoT) architecture to maximize the model’s capacity to learn from richly diverse multimodal information. Following the same principle of capacity maximization, it utilizes two separate encoders to capture pixel-level and semantic-level features of an image. The overall framework follows a Next Group of Token Prediction paradigm, where the model is trained to predict the next group of language or visual tokens as a compression target.
BAGEL scales MoT’s capacity through Pre-training, Continued Training, and Supervised Finetuning on trillions of interleaved multimodal tokens spanning language, image, video, and web data. It surpasses open models on standard understanding and generation benchmarks and demonstrates advanced in-context multimodal abilities like free-form image editing, future frame prediction, 3D manipulation, world navigation, and sequential reasoning.
<p align="center"><img src="assets/arch.png" width="95%"></p>
## 🌱 Emerging Properties
<p align="center"><img src="assets/emerging_curves.png" width="95%"></p>
As we scale up BAGEL’s pretraining with more multimodal tokens, we observe consistent performance gains across understanding, generation, and editing tasks. Different capabilities emerge at distinct training stages—multimodal understanding and generation appear early, followed by basic editing, while complex, intelligent editing emerges later. This staged progression suggests an emergent pattern, where advanced multimodal reasoning builds on well-formed foundational skills. Ablation studies further show that combining VAE and ViT features significantly improves intelligent editing, underscoring the importance of visual-semantic context in enabling complex multimodal reasoning and further supporting its role in the emergence of advanced capabilities. -->
## 📢 News
We sincerely thank all contributors from the open community for their valuable support.
- **May 26, 2025:** Thanks to [@neverbiasu](https://github.com/neverbiasu) for contributing the [ComfyUI integration](https://github.com/neverbiasu/ComfyUI-BAGEL).
- **May 25, 2025:** Special thanks to [@LeanModels](https://github.com/LeanModels) for providing the [DF11-compressed version](https://huggingface.co/DFloat11/BAGEL-7B-MoT-DF11), and to [@Gapeleon](https://huggingface.co/Gapeleon) for the [INT8-compressed version](https://huggingface.co/Gapeleon/bytedance_BAGEL-7B-MoT-INT8). We also appreciate [@gluttony-10](https://github.com/gluttony-10) for contributions to the [Windows package](https://github.com/ByteDance-Seed/Bagel/issues/51).
- **May 24, 2025:** Together with [@wangwei1237](https://github.com/wangwei1237), [@gluttony-10](https://github.com/gluttony-10), and [@KingNish24](https://github.com/KingNish24), we built a Gradio [app](app.py) and launched a [Hugging Face Space](https://huggingface.co/spaces/ByteDance-Seed/BAGEL).
- **May 23, 2025:** We have provided a training guideline in [TRAIN](./TRAIN.md).
- **May 20, 2025:** We released the official [website](https://bagel-ai.org/), [demo](https://demo.bagel-ai.org/), [model](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT), and [report](https://arxiv.org/abs/2505.14683) for BAGEL.
## 📮 Notice
**Call for Bad Cases:** If you have encountered any cases where the model performs poorly, we would greatly appreciate it if you could share them in the [issue#11](https://github.com/ByteDance-Seed/Bagel/issues/11) or [Discord](https://discord.gg/Z836xxzy).
**About Inference Hyperparameters** (an illustrative configuration sketch follows this list):
- **`cfg_text_scale`:** Controls how strongly the model follows the text prompt. `1.0` disables text guidance. Typical range: `4.0–8.0`.
- **`cfg_image_scale`:** Controls how much the model preserves input image details. `1.0` disables image guidance. Typical range: `1.0–2.0`.
- **`cfg_interval`:** Fraction of denoising steps where CFG is applied. Later steps can skip CFG to reduce computation. Typical: `[0.4, 1.0]`.
- **`timestep_shift`:** Shifts the distribution of denoising steps. Higher values allocate more steps at the start (affects layout); lower values allocate more at the end (improves details).
- **`num_timesteps`:** Total denoising steps. Typical: `50`.
- **`cfg_renorm_min`:** Minimum value for CFG-Renorm. `1.0` disables renorm. Typical: `0`.
- **`cfg_renorm_type`:** CFG-Renorm method:
- `global`: Normalize over all tokens and channels (default for T2I).
- `channel`: Normalize across channels for each token.
- `text_channel`: Like `channel`, but only applies to text condition (good for editing, may cause blur).
- **If edited images appear blurry, try `global` CFG-Renorm, decrease `cfg_renorm_min` or decrease `cfg_scale`.**
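As a rough, non-official starting point, the settings above might be collected like this. The dictionary name and the exact values (notably `timestep_shift`) are assumptions drawn from the typical ranges listed; the call that consumes them depends on your inference code (e.g. `inference.ipynb`).
```python
# Illustrative values only, assembled from the typical ranges listed above.
inference_hyper = dict(
    cfg_text_scale=4.0,        # text guidance; 1.0 disables it (typical 4.0-8.0)
    cfg_image_scale=1.5,       # input-image preservation for editing; 1.0 disables it (typical 1.0-2.0)
    cfg_interval=[0.4, 1.0],   # fraction of denoising steps where CFG is applied
    timestep_shift=3.0,        # assumed value; higher -> more steps on layout, lower -> more on details
    num_timesteps=50,          # total denoising steps
    cfg_renorm_min=0.0,        # 1.0 disables CFG-Renorm
    cfg_renorm_type="global",  # "global" (T2I default), "channel", or "text_channel"
)
```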
## 🔥 Quick Start
1️⃣ Set up environment
```bash
git clone https://github.com/bytedance-seed/BAGEL.git
cd BAGEL
conda create -n bagel python=3.10 -y
conda activate bagel
pip install -r requirements.txt
```
2️⃣ Download pretrained checkpoint
```python
from huggingface_hub import snapshot_download

save_dir = "/path/to/save/BAGEL-7B-MoT"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)
```
3️⃣ Go to [`inference.ipynb`](inference.ipynb) to start playing with BAGEL!
4️⃣ Use Gradio WebUI to start playing with BAGEL!
```bash
pip install gradio
python app.py
```
## 🔥 Train & Eval
### Train
```bash
bash scripts/train.sh
```
You can replace the variables in the script with your own before running.
See [TRAIN](TRAIN.md) for more details.
### Eval
We provide the scripts for evaluating VLM, T2I and Editing benchmarks.
Please see [EVAL](EVAL.md) for more details.
## 📊 Benchmarks
### 1. Visual Understanding
| Model | MME ↑ | MMBench ↑ | MMMU ↑ | MM-Vet ↑ | MathVista ↑ |
| ------------------- | ----------: | ----------: | -------: | -------: | ----------: |
| Janus-Pro-7B | - | 79.2 | 41.0 | 50.0 | – |
| Qwen2.5-VL-7B | 2347 | 83.5 | **58.6** | 67.1 | 68.2 |
| **BAGEL** | **2388** | **85.0** | 55.3 | **67.2** | **73.1** |
### 2. Text-to-Image Generation
| Model | GenEval ↑ | WISE ↑|
| ------------ | --------- | --------- |
| Janus-Pro-7B | 0.80 | 0.35 |
| SD3-Medium | 0.74 | - |
| FLUX-1-dev | 0.82 | 0.50 |
| **BAGEL** | - | **0.52** |
| **BAGEL + CoT** | **0.88** | **0.70** |
### 3. Image Editing
| Model | GEdit-Bench-EN (SC) ↑ | GEdit-Bench-EN (PQ) ↑ | GEdit-Bench-EN (O) ↑ | IntelligentBench ↑ |
| ------------- | --------------------- | --------------------- | ------------------- | ------------------ |
| Step1X-Edit | 7.09 | 6.76 | **6.70** | 14.9 |
| Gemini-2-exp. | 6.73 | 6.61 | 6.32 | **57.6** |
| **BAGEL** | **7.36** | **6.83** | 6.52 | 44.0 |
| **BAGEL+CoT** | – | – | – | 55.3 |
## ✍️ Citation
```bibtex
@article{deng2025bagel,
title = {Emerging Properties in Unified Multimodal Pretraining},
author = {Deng, Chaorui and Zhu, Deyao and Li, Kunchang and Gou, Chenhui and Li, Feng and Wang, Zeyu and Zhong, Shu and Yu, Weihao and Nie, Xiaonan and Song, Ziang and Shi, Guang and Fan, Haoqi},
journal = {arXiv preprint arXiv:2505.14683},
year = {2025}
}
```
## 📜 License
BAGEL is licensed under the Apache License 2.0.
# Data preparation
We provide data examples for **T2I**, **Editing**, and **VLM** tasks. The T2I dataset is generated using [FLUX.1‑dev](https://huggingface.co/black-forest-labs/FLUX.1-dev); the editing examples are randomly sampled from [SEED‑Data‑Edit‑Part3](https://huggingface.co/datasets/AILab-CVC/SEED-Data-Edit-Part2-3); and the VLM set is sourced from [LLaVA‑OneVision‑Data](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data).
We offer examples in both raw-image folder and parquet shard formats. For other data formats, you can use our dataset code as a template and extend it as needed.
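For instance, a custom parquet-backed dataset could follow the same pattern as the provided `UnifiedEditIterableDataset`. This is a hedged sketch: the file location (`data/interleave_datasets/`), the class name, and the parquet column names `caption` and `image` are assumptions about your own data.
```python
# Hypothetical data/interleave_datasets/my_t2i_dataset.py -- reuses the shipped base classes.
import io

from PIL import Image

from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
from ..data_utils import pil_img2rgb


class MyT2IIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
    def parse_row(self, row):
        data = self._init_data()
        # Condition on the caption (no loss) and supervise the image, mirroring how
        # UnifiedEditIterableDataset builds its target-image entry.
        data = self._add_text(data, row["caption"], need_loss=False)
        data = self._add_image(
            data,
            pil_img2rgb(Image.open(io.BytesIO(row["image"]))),
            need_loss=True,
            need_vae=False,
            need_vit=False,
        )
        return data
```
The new class would then be registered in `DATASET_REGISTRY` and given a matching entry in `DATASET_INFO` (see `data/dataset_info.py`).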
1. **Download the sample dataset**
```bash
wget -O bagel_example.zip \
https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/bagel_example.zip
unzip bagel_example.zip -d /data
```
2. **Expected hierarchy**
```text
bagel_example
├── t2i/                       # text-to-image (parquet)
├── editing/                   # image editing (parquet)
│   ├── seedxedit_multi/
│   └── parquet_info/
└── vlm/
    ├── images/                # JPEG / PNG frames
    └── llava_ov_si.jsonl      # vision‑language SFT conversations
```
3. Edit every `your_data_path` placeholder in **`data/dataset_info.py`**.
4. *(Optional)* Extend `DATASET_INFO` with your own parquet shards or JSONL files to mix extra data (see the sketch after this list).
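As a sketch of step 4, an extra T2I shard set could be registered like this; the key `my_t2i`, the paths, and the counts are placeholders, and the new dataset name must also be listed under `dataset_names` in your dataset YAML.
```python
# data/dataset_info.py -- hypothetical extra entry under the existing 't2i_pretrain' group.
DATASET_INFO['t2i_pretrain']['my_t2i'] = {
    'data_dir': 'your_data_path/my_t2i_parquets',  # directory holding the parquet shards
    'num_files': 10,                               # data units sharded across ranks and workers
    'num_total_samples': 1000,                     # total number of samples in the dataset
}
```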
---
# Training
The baseline training recipe looks like this (replace environment variables with real paths or values):
```shell
# Pre-training
torchrun \
--nnodes=$num_nodes \
--node_rank=$node_rank \
--nproc_per_node=8 \
--master_addr=$master_addr \
--master_port=$master_port \
train/pretrain_unified_navit.py \
--dataset_config_file ./data/configs/example.yaml \
--llm_path $llm_path \
--vae_path $vae_path \
--vit_path $vit_path \
--layer_module Qwen2MoTDecoderLayer \
--use_flex True \
--resume_from $resume_from \
--results_dir $output_path \
--checkpoint_dir $ckpt_path \
--max_latent_size 64 # 32 for low-resolution pre-training
# Fine-tuning
torchrun \
--nnodes=$num_nodes \
--node_rank=$node_rank \
--nproc_per_node=8 \
--master_addr=$master_addr \
--master_port=$master_port \
train/pretrain_unified_navit.py \
--dataset_config_file ./data/configs/example.yaml \
--model_path $model_path \
--layer_module Qwen2MoTDecoderLayer \
--max_latent_size 64 \
--resume-from $model_path \
--finetune_from_hf True \
--auto_resume True \
--resume-model-only True \
--finetune-from-ema True \
--log_every 1 \
--lr 2e-5 \
--num_worker 1 \
--expected_num_tokens 10240 \
--max_num_tokens 11520 \
--max_num_tokens_per_sample 10240
```
- **When fine-tuning BAGEL, set `max_latent_size=64` to ensure the correct pretrained weights are loaded.** If this is not set, an out-of-bounds error may occur.
- The total value of `num_used_data` should be greater than `NUM_GPUS × NUM_WORKERS`. (For toy data, use `num_worker=1`.)
- For T2I-only fine-tuning, set `visual_und=False`. For VLM-only fine-tuning, set `visual_gen=False`.
- For debugging purposes, use smaller values for `expected_num_tokens`, `max_num_tokens`, and `max_num_tokens_per_sample`.
- When fine-tuning on toy data, the loss behaves as follows:
```shell
[2025-05-25 17:01:37] (step=0000000) Train Loss mse: 0.4063, Train Loss ce: 0.5504, Train Steps/Sec: 0.01,
[2025-05-25 17:01:40] (step=0000001) Train Loss mse: 0.4121, Train Loss ce: 0.8152, Train Steps/Sec: 0.44,
[2025-05-25 17:01:42] (step=0000002) Train Loss mse: 0.3876, Train Loss ce: 1.3411, Train Steps/Sec: 0.40,
[2025-05-25 17:01:45] (step=0000003) Train Loss mse: 0.3825, Train Loss ce: 0.7360, Train Steps/Sec: 0.44,
```
You are encouraged to adjust any of these hyperparameters to fit your GPU budget and the scale of your dataset. If you encounter any issues, please open an issue for assistance. 🎉
## Model config
| Argument | Default | Description |
| ---------------------------- | ------------------------------------------- | --------------------------------------------------------------- |
| `llm_path` | `hf/Qwen2.5-0.5B-Instruct` | Language‑model backbone (HuggingFace repo or local folder). |
| `vae_path` | `flux/vae/ae.safetensors` | Pre‑trained VAE checkpoint for latent diffusion. |
| `vit_path` | `hf/siglip-so400m-14-980-flash-attn2-navit` | SigLIP ViT used for image understanding. |
| `max_latent_size` | `32` | Maximum latent grid side; defines highest generable resolution. |
| `latent_patch_size` | `2` | VAE pixels represented by one latent patch. |
| `vit_max_num_patch_per_side` | `70` | Max ViT patches per image side after resizing. |
| `text_cond_dropout_prob` | `0.1` | Probability to drop text conditioning while training. |
| `vae_cond_dropout_prob` | `0.3` | Dropout on VAE latent inputs. |
| `vit_cond_dropout_prob` | `0.3` | Dropout on visual features. |
*(See `ModelArguments` for many more options.)*
## Data config
| Argument | Default | Description |
| --------------------------- | --------------------------- | --------------------------------------------------------- |
| `dataset_config_file` | `data/configs/example.yaml` | YAML that groups datasets and assigns sampling weights. |
| `num_workers` | `4` | Background workers per rank for the PyTorch `DataLoader`. |
| `prefetch_factor` | `2` | Batches pre‑fetched by each worker. |
| `max_num_tokens_per_sample` | `16384` | Skip raw samples longer than this. |
| `max_num_tokens` | `36864` | Hard cap for a packed batch (prevents OOM). |
| `max_buffer_size` | `50` | Overflow buffer length for oversized samples. |
| `data_seed` | `42` | Seed for reproducible shuffling and sampling. |
## Training config
| Argument | Default | Description |
| -------------------------------------- | ---------------------- | ------------------------------------------------------ |
| `total_steps` | `500_000` | Optimiser steps to run. |
| `lr` | `1e-4` | Peak learning rate after warm‑up. |
| `lr_scheduler` | `constant` | Learning‑rate schedule (`constant` or `cosine`). |
| `warmup_steps` | `2000` | Linear warm‑up duration. |
| `ema` | `0.9999` | Exponential moving‑average decay for model weights. |
| `max_grad_norm` | `1.0` | Gradient‑clipping threshold. |
| `save_every` | `2000` | Checkpoint frequency (steps). |
| `visual_gen / visual_und` | `True` | Enable image generation / understanding branches. |
| `freeze_llm / freeze_vit / freeze_vae` | `False / False / True` | Freeze selected modules to save VRAM or for ablations. |
| `use_flex` | `True` (in example) | Enable FLEX packing for higher GPU utilisation. |
| `sharding_strategy` | `HYBRID_SHARD` | FSDP sharding mode. |
| `num_shard` | `8` | Parameter shards per rank in HYBRID mode. |
**Distributed‑launch environment variables** (an illustrative single-node setting follows the table)
| Var | Meaning |
| ----------------------------- | --------------------------------- |
| `num_nodes` / `node_rank` | Multi‑node orchestration indices. |
| `nproc_per_node` | Number of GPUs per node. |
| `master_addr` / `master_port` | NCCL rendezvous endpoint. |
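For a single-node launch, these might be set as follows (illustrative values; `nproc_per_node` is already fixed to 8 in the commands above):
```shell
# Illustrative single-node settings for the torchrun commands above.
export num_nodes=1
export node_rank=0
export master_addr=127.0.0.1
export master_port=29500
```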
## Logging config
| Argument | Default | Description |
| ---------------- | --------------------- | ---------------------------------------------------- |
| `results_dir` | `results` | Root directory for logs and metrics. |
| `checkpoint_dir` | `results/checkpoints` | Checkpoints are saved here. |
| `log_every` | `10` | Steps between console / W\&B logs. |
| `wandb_project` | `bagel` | Weights & Biases project name. |
| `wandb_name` | `run` | Run name inside the project. |
| `wandb_offline` | `False` | Switch to offline mode (logs locally, sync later). |
| `wandb_resume` | `allow` | Resumption policy if an existing run ID is detected. |
> **Tip** Export `WANDB_API_KEY` before launching if you want online dashboards.
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
t2i_pretrain:
  dataset_names:
    - t2i
  image_transform_args:
    image_stride: 16
    max_image_size: 1024
    min_image_size: 512
  is_mandatory: true
  num_used_data: # The sum should be larger than NUM_GPUS x NUM_WORKERS
    - 10
  weight: 1
unified_edit:
  dataset_names:
    - seedxedit_multi
  image_transform_args:
    image_stride: 16
    max_image_size: 1024
    min_image_size: 512
  vit_image_transform_args:
    image_stride: 14
    max_image_size: 518
    min_image_size: 224
  is_mandatory: false
  num_used_data:
    - 10
  weight: 1
vlm_sft:
  dataset_names:
    - llava_ov
  image_transform_args:
    image_stride: 14
    max_image_size: 980
    min_image_size: 378
    max_pixels: 2_007_040
  frame_sampler_args:
    max_num_frames: 12
    min_num_frames: 8
  is_mandatory: true
  shuffle_lines: True
  shuffle_seed: 0
  num_used_data:
    - 1000
  weight: 1
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import math
import random
from PIL import Image
import torch
from torch.nn.attention.flex_attention import or_masks, and_masks
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
def remove_noise_mask(b, h, q_idx, kv_idx):
return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
value = i if model in ['full', 'noise'] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if model == 'noise' else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
def patchify(image, patch_size):
p = patch_size
c, h, w = image.shape
assert h % p == 0 and w % p == 0
image = image.reshape(c, h // p, p, w // p, p)
image = torch.einsum("chpwq->hwpqc", image)
image = image.reshape(-1, p**2 * c)
return image
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
coords_h = torch.arange(0, num_patches_h)
coords_w = torch.arange(0, num_patches_w)
pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
return pos_ids
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
return pos_ids
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
"""
nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
a sample, where each sample contains multiple splits with different attn modes.
nested_attn_modes: whether to use full attn in each split.
"""
sample_len = sum(split_lens)
attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
assert attn_mode in ['causal', 'full', 'noise']
if attn_mode == "causal":
attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
attention_mask[csum:csum + s, :csum] = 1
else:
attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
attention_mask[csum:csum + s, :csum] = 1
csum += s
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
if attn_mode == "noise":
attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
csum += s
attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
~attention_mask, float("-inf")
)
return attention_mask
def split_integer_exp_decay(S, ng_sample_decay=1.0):
if ng_sample_decay == 1.0:
N = random.randint(1, S)
else:
base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
N = random.choices(list(range(1, S + 1)), p, k=1)[0]
cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
return result, cumsum
def pil_img2rgb(image):
if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
image = image.convert("RGBA")
white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
white.paste(image, mask=image.split()[3])
image = white
else:
image = image.convert("RGB")
return image
def add_special_tokens(tokenizer):
all_special_tokens = []
for k, v in tokenizer.special_tokens_map.items():
if isinstance(v, str):
all_special_tokens.append(v)
elif isinstance(v, list):
all_special_tokens += v
new_tokens = []
if '<|im_start|>' not in all_special_tokens:
new_tokens.append('<|im_start|>')
if '<|im_end|>' not in all_special_tokens:
new_tokens.append('<|im_end|>')
if '<|vision_start|>' not in all_special_tokens:
new_tokens.append('<|vision_start|>')
if '<|vision_end|>' not in all_special_tokens:
new_tokens.append('<|vision_end|>')
num_new_tokens = tokenizer.add_tokens(new_tokens)
bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
new_token_ids = dict(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
start_of_image=start_of_image,
end_of_image=end_of_image,
)
return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction='square'):
if x == 0:
return x
if loss_reduction == 'token':
return 1
if loss_reduction == 'sample':
return 1 / x
if loss_reduction == 'square':
return 1 / (x ** 0.5)
raise NotImplementedError(loss_reduction)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from .interleave_datasets import UnifiedEditIterableDataset
from .t2i_dataset import T2IIterableDataset
from .vlm_dataset import SftJSONLIterableDataset
DATASET_REGISTRY = {
't2i_pretrain': T2IIterableDataset,
'vlm_sft': SftJSONLIterableDataset,
'unified_edit': UnifiedEditIterableDataset,
}
DATASET_INFO = {
't2i_pretrain': {
't2i': {
'data_dir': 'your_data_path/bagel_example/t2i', # path of the parquet files
'num_files': 10, # number of data units to be sharded across all ranks and workers
'num_total_samples': 1000, # number of total samples in the dataset
},
},
'unified_edit':{
'seedxedit_multi': {
'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
'num_files': 10,
'num_total_samples': 1000,
"parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json', # information of the parquet files
},
},
'vlm_sft': {
'llava_ov': {
'data_dir': 'your_data_path/bagel_example/vlm/images',
'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
'num_total_samples': 1000
},
},
}
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import random
import torch
class DistributedIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
self.dataset_name = dataset_name
self.local_rank = local_rank
self.world_size = world_size
self.num_workers = num_workers
self.rng = random.Random()
self.data_paths = None
def get_data_paths(self, *args, **kwargs):
raise NotImplementedError
def set_epoch(self, seed=42):
if self.data_paths is None:
return
if isinstance(self.data_paths[0], tuple):
data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
elif isinstance(self.data_paths[0], str):
data_paths = sorted(self.data_paths)
else:
raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
self.rng.seed(seed)
self.rng.shuffle(data_paths)
num_files_per_rank = len(data_paths) // self.world_size
local_start = self.local_rank * num_files_per_rank
local_end = (self.local_rank + 1) * num_files_per_rank
self.num_files_per_rank = num_files_per_rank
self.data_paths_per_rank = data_paths[local_start:local_end]
def get_data_paths_per_worker(self):
if self.data_paths is None:
return None
info = torch.utils.data.get_worker_info()
if info is None:
# Single worker: Use all files assigned to the rank
return self.data_paths_per_rank, 0
worker_id = info.id
num_files_per_worker = self.num_files_per_rank // info.num_workers
start = num_files_per_worker * worker_id
end = num_files_per_worker * (worker_id + 1)
data_paths_per_worker = self.data_paths_per_rank[start:end]
return data_paths_per_worker[::-1], worker_id
def __iter__(self):
raise NotImplementedError
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from .edit_dataset import UnifiedEditIterableDataset
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import random
from PIL import Image, ImageFile, PngImagePlugin
from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
from ..data_utils import pil_img2rgb
Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
def parse_row(self, row):
image_num = len(row["image_list"])
# randomly choose start and end, return [0, 1] when only two images
start_idx = random.choice(range(image_num - 1))
max_end = min(start_idx + 3, image_num)
end_idx = random.choice(range(start_idx + 1, max_end))
data = self._init_data()
data = self._add_image(
data,
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
need_loss=False,
need_vae=True,
need_vit=True,
)
if end_idx - start_idx > 1 and random.random() < 0.5: # concatenate multiple instructions
if end_idx == image_num - 1:
end_idx -= 1
instruction = ""
for idx in range(start_idx + 1, end_idx + 1):
instruction += random.choice(row["instruction_list"][idx-1]) + ". "
data = self._add_text(data, instruction.rstrip(), need_loss=False)
data = self._add_image(
data,
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
need_loss=True,
need_vae=False,
need_vit=False,
)
else:
for idx in range(start_idx + 1, end_idx + 1):
instruction = random.choice(row["instruction_list"][idx-1])
data = self._add_text(data, instruction, need_loss=False)
if idx != end_idx:
data = self._add_image(
data,
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
need_loss=True,
need_vae=True,
need_vit=True,
)
else:
data = self._add_image(
data,
pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
need_loss=True,
need_vae=False,
need_vit=False,
)
return data
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import pyarrow.parquet as pq
from ..distributed_iterable_dataset import DistributedIterableDataset
from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
class InterleavedBaseIterableDataset(DistributedIterableDataset):
def _init_data(self):
data = {
'sequence_plan': [],
'text_ids_list': [],
'image_tensor_list': [],
'num_tokens': 0,
}
return data
def _add_text(self, data, text, need_loss, enable_cfg=True):
text_ids = self.tokenizer.encode(text)
data['num_tokens'] += len(text_ids)
data['text_ids_list'].append(text_ids)
data['sequence_plan'].append(
{
'type': 'text',
'enable_cfg': int(enable_cfg),
'loss': int(need_loss),
'special_token_loss': 0,
'special_token_label': None,
}
)
return data
def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
assert need_loss or need_vae or need_vit
if need_loss:
data['sequence_plan'].append(
{
'type': 'vae_image',
'enable_cfg': 0,
'loss': 1,
'special_token_loss': 0,
'special_token_label': None,
}
)
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
data['num_tokens'] += width * height // self.transform.stride ** 2
data['image_tensor_list'].append(image_tensor)
if need_vae:
data['sequence_plan'].append(
{
'type': 'vae_image',
'enable_cfg': int(enable_cfg),
'loss': 0,
'special_token_loss': 0,
'special_token_label': None,
}
)
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
data['num_tokens'] += width * height // self.transform.stride ** 2
data['image_tensor_list'].append(image_tensor.clone())
if need_vit:
data['sequence_plan'].append(
{
'type': 'vit_image',
'enable_cfg': int(enable_cfg),
'loss': 0,
'special_token_loss': 0,
'special_token_label': None,
},
)
vit_image_tensor = self.vit_transform(image)
height, width = vit_image_tensor.shape[1:]
data['num_tokens'] += width * height // self.vit_transform.stride ** 2
data['image_tensor_list'].append(vit_image_tensor)
return data
def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
assert int(need_loss) + int(need_vae) == 1
if need_loss:
for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
current_sequence_plan = {
'type': 'vae_image',
'enable_cfg': 0,
'loss': 1,
'special_token_loss': 0,
'special_token_label': None,
'split_start': idx == 0,
'split_end': idx == len(frames) - 1,
}
if idx < len(frame_indexes) - 1:
current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
data['sequence_plan'].append(current_sequence_plan)
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
data['image_tensor_list'].append(image_tensor)
data['num_tokens'] += width * height // self.transform.stride ** 2
elif need_vae:
for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
current_sequence_plan = {
'type': 'vae_image',
'enable_cfg': int(enable_cfg),
'loss': 0,
'special_token_loss': 0,
'special_token_label': None,
'split_start': idx == 0,
'split_end': idx == len(frames) - 1,
}
if idx < len(frame_indexes) - 1:
current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
data['sequence_plan'].append(current_sequence_plan)
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
data['image_tensor_list'].append(image_tensor)
data['num_tokens'] += width * height // self.transform.stride ** 2
return data
class ParquetStandardIterableDataset(DistributedIterableDataset):
def __init__(
self, dataset_name, transform, tokenizer, vit_transform,
data_dir_list, num_used_data, parquet_info,
local_rank=0, world_size=1, num_workers=8, data_status=None,
):
"""
data_dir_list: list of data directories containing parquet files
num_used_data: list of number of sampled data paths for each data directory
vit_transform: input transform for vit model.
"""
super().__init__(dataset_name, local_rank, world_size, num_workers)
self.transform = transform
self.vit_transform = vit_transform
self.tokenizer = tokenizer
self.data_status = data_status
self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
self.set_epoch()
def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
row_groups = []
for data_dir, num_data_path in zip(data_dir_list, num_used_data):
data_paths = get_parquet_data_paths([data_dir], [num_data_path])
for data_path in data_paths:
if data_path in parquet_info.keys():
num_row_groups = parquet_info[data_path]['num_row_groups']
for rg_idx in range(num_row_groups):
row_groups.append((data_path, rg_idx))
return row_groups
def parse_row(self, row):
raise NotImplementedError
def __iter__(self):
file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
if self.data_status is not None:
global_row_group_start_id = self.data_status[worker_id][0]
row_start_id = self.data_status[worker_id][1] + 1
else:
global_row_group_start_id = 0
row_start_id = 0
print(
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
)
while True:
file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
file_paths_per_worker_, start=global_row_group_start_id
):
fs = init_arrow_pf_fs(parquet_file_path)
with fs.open_input_file(parquet_file_path) as f:
try:
fr = pq.ParquetFile(f)
df = fr.read_row_group(row_group_id).to_pandas()
df = df.iloc[row_start_id:]
except Exception as e:
print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
continue
for row_idx, row in df.iterrows():
try:
data = self.parse_row(row)
if len(data) == 0:
continue
data['data_indexes'] = {
"data_indexes": [global_row_group_idx, row_idx],
"worker_id": worker_id,
"dataset_name": self.dataset_name,
}
except Exception as e:
print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
continue
yield data
row_start_id = 0
global_row_group_start_id = 0
print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")