Commit 727428ec authored by jerrrrry

Initial commit CI/CD

blank_issues_enabled: false
---
name: Error report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
Thank you for the error report; we appreciate it a lot.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
**Describe the bug**
A clear and concise description of what the bug is.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Did you make any modifications to the code or config? Do you understand what you modified?
3. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add additional information that may help locate the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Error traceback**
If applicable, paste the error traceback here.
```none
A placeholder for the traceback.
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Describe the feature**
**Motivation**
A clear and concise description of the motivation of the feature.
Ex1. It is inconvenient when \[....\].
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
**Related resources**
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
**Additional context**
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
---
name: General questions
about: Ask general questions to get help
title: ''
labels: ''
assignees: ''
---
---
name: Reimplementation Questions
about: Ask about questions during model reimplementation
title: ''
labels: reimplementation
assignees: ''
---
**Notice**
There are several common situations in reimplementation issues, as listed below:
1. Reimplement a model in the model zoo using the provided configs
2. Reimplement a model in the model zoo on another dataset (e.g., a custom dataset)
3. Reimplement a custom model whose components are all implemented in HunyuanDiT
4. Reimplement a custom model with new modules implemented by yourself
What to do differs by case:
- For cases 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.
- For cases 2 & 4, please understand that we cannot help much here, because we usually do not know the full code and users are responsible for the code they write.
- One suggestion for cases 2 & 4 is to first check whether the bug lies in the self-implemented code or in the original code. For example, first make sure the same model runs well on the supported datasets. If you still need help, describe what you have done and what you obtained in the issue, follow the steps in the following sections, and be as clear as possible so that we can better help you.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The issue has not been fixed in the latest version.
**Describe the issue**
A clear and concise description of the problem you met and what you have done.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Which config did you run?
```none
A placeholder for the config.
```
3. Did you make any modifications to the code or config? Do you understand what you modified?
4. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add additional information that may help locate the problem, such as
1. How you installed PyTorch \[e.g., pip, conda, source\]
2. Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Results**
If applicable, paste the related results here, e.g., what you expect and what you get.
```none
A placeholder for results comparison
```
**Issue fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
---
name: General questions
about: Ask general questions to get help
title: ''
labels: ''
assignees: ''
---
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Describe the feature**
**Motivation**
A clear and concise description of the motivation for the feature.
Ex1. It is inconvenient when \[....\].
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
**Related resources**
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
**Additional context**
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here; that would be much appreciated.
---
name: Reimplementation questions
about: Ask questions about model reimplementation
title: ''
labels: reimplementation
assignees: ''
---
**Notice**
There are several common situations in reimplementation issues, as listed below:
1. Reimplement a model in the model zoo using the provided configs
2. Reimplement a model in the model zoo on another dataset (e.g., a custom dataset)
3. Reimplement a custom model whose components are all implemented in HunyuanDiT
4. Reimplement a custom model with new modules implemented by yourself
What to do differs by case:
- For cases 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.
- For cases 2 & 4, please understand that we cannot help much here, because we usually do not know the full code and users are responsible for the code they write.
- One suggestion for cases 2 & 4 is to first check whether the bug lies in the self-implemented code or in the original code. For example, first make sure the same model runs well on the supported datasets. If you still need help, describe what you have done and what you obtained in the issue, follow the steps in the following sections, and be as clear as possible so that we can better help you.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The issue has not been fixed in the latest version.
**Describe the issue**
A clear and concise description of the problem you met and what you have done.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Which config did you run?
```none
A placeholder for the config.
```
3. Did you make any modifications to the code or config? Do you understand what you modified?
4. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect the necessary environment information and paste it here.
2. You may add additional information that may help locate the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Results**
If applicable, paste the related results here, e.g., what you expected and what you got.
```none
A placeholder for results comparison
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here; that would be much appreciated!
---
name: Error report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
Thank you for the error report; we appreciate it a lot.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
**Describe the bug**
A clear and concise description of what the bug is.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Did you make any modifications to the code or config? Do you understand what you modified?
3. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect the necessary environment information and paste it here.
2. You may add additional information that may help locate the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Error traceback**
If applicable, paste the error traceback here.
```none
A placeholder for the traceback.
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here; that would be much appreciated!
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/
.idea/
.DS_Store
results/
trt/activate.sh
trt/deactivate.sh
*.onnx
ckpts/
app/output/
[submodule "comfyui-hydit"]
path = comfyui-hydit
url = https://github.com/zml-ai/comfyui-hydit
# IndexKits
[TOC]
## Introduction
Index Kits (`index_kits`) is a toolkit for streaming Arrow data.
* Supports creating datasets from configuration files.
* Supports creating index v2 from arrows.
* Supports the creation, reading, and usage of index files.
* Supports the creation, reading, and usage of multi-bucket, multi-resolution indexes.
`index_kits` has the following advantages:
* Optimizes memory usage for streaming reads, supporting billion-scale data.
* Optimizes the memory usage and reading speed of index files.
* Supports the creation of Base/Multireso Index V2 datasets, including data filtering, repeating, and deduplication during creation.
* Supports loading various dataset types (Arrow files, Base Index V2, multiple Base Index V2, Multireso Index V2, multiple Multireso Index V2).
* Uses a unified API to access images, text, and other attributes across all dataset types.
* Built-in shuffle and `resize_and_crop`, which returns the coordinates of the crop.
## Installation
* **Install from pre-compiled whl (recommended)**
* Install from source
```shell
cd IndexKits
pip install -e .
```
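To verify the installation, you can print the package version (defined in `index_kits/__init__.py`):
```python
import index_kits

print(index_kits.__version__)  # e.g. 0.3.5
```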
## Usage
### Loading an Index V2 File
```python
from index_kits import ArrowIndexV2

index_manager = ArrowIndexV2('data.json')
# You can shuffle the dataset using a seed. If a seed is provided (i.e., the seed is not None),
# this shuffle will not affect any random state.
index_manager.shuffle(1234, fast=True)

for i in range(len(index_manager)):
    # Get an image (requires the arrow to contain an image/binary column):
    pil_image = index_manager.get_image(i)
    # Get text (requires the arrow to contain a text_zh column):
    text = index_manager.get_attribute(i, column='text_zh')
    # Get MD5 (requires the arrow to contain an md5 column):
    md5 = index_manager.get_md5(i)
    # Get any attribute by specifying the column (must be contained in the arrow):
    ocr_num = index_manager.get_attribute(i, column='ocr_num')
    # Get multiple attributes at once
    item = index_manager.get_data(i, columns=['text_zh', 'md5'])  # i: in-dataset index
    print(item)
    # {
    #     'index': 3,           # in-json index
    #     'in_arrow_index': 3,  # in-arrow index
    #     'arrow_name': '/HunYuanDiT/dataset/porcelain/00000.arrow',
    #     'text_zh': 'Fortune arrives with the auspicious green porcelain tea cup',
    #     'md5': '1db68f8c0d4e95f97009d65fdf7b441c'
    # }
```
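`resize_and_crop` (listed in the feature set above) resizes without changing the aspect ratio, crops to the target size, and returns the crop position. A minimal sketch, reusing the `index_manager` from the example above and assuming a 512x512 target:
```python
# Resize while preserving the aspect ratio, then crop to the 512x512 target; crop_type can be
# 'center' or 'random'. Returns the cropped image and the (crop_left, crop_top) position.
image = index_manager.get_image(0)
image, crops_coords_left_top = index_manager.resize_and_crop(image, (512, 512), crop_type='center')
```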
### Loading a Set of Arrow Files
If you have a small batch of arrow files, you can directly use `IndexV2Builder` to load these arrow files without creating an Index V2 file.
```python
from index_kits import IndexV2Builder
index_manager = IndexV2Builder(arrow_files).to_index_v2()
```
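Here `arrow_files` is simply a list of arrow file paths; a minimal sketch that collects them with a glob (the path is illustrative, matching the dataset layout used elsewhere in this document):
```python
from glob import glob

from index_kits import IndexV2Builder

# Collect a small batch of arrow files and build an in-memory Index V2 from them.
arrow_files = sorted(glob('/HunYuanDiT/dataset/porcelain/arrows/*.arrow'))
index_manager = IndexV2Builder(arrow_files).to_index_v2()
print(len(index_manager))
```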
### Loading a Multi-Bucket (Multi-Resolution) Index V2 File
When using a multi-bucket (multi-resolution) index file, refer to the following example code, especially the definition of SimpleDataset and the usage of MultiResolutionBucketIndexV2.
```python
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T

from index_kits import MultiResolutionBucketIndexV2
from index_kits.sampler import BlockDistributedSampler


class SimpleDataset(Dataset):
    def __init__(self, index_file, batch_size, world_size):
        # When using multi-bucket (multi-resolution), batch_size and world_size need to be
        # specified for underlying data alignment.
        self.index_manager = MultiResolutionBucketIndexV2(index_file, batch_size, world_size)
        self.flip_norm = T.Compose([
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])

    def shuffle(self, seed, fast=False):
        self.index_manager.shuffle(seed, fast=fast)

    def __len__(self):
        return len(self.index_manager)

    def __getitem__(self, idx):
        image = self.index_manager.get_image(idx)
        original_size = image.size  # (w, h)
        target_size = self.index_manager.get_target_size(idx)  # (w, h)
        image, crops_coords_left_top = self.index_manager.resize_and_crop(image, target_size)
        image = self.flip_norm(image)
        text = self.index_manager.get_attribute(idx, column='text_zh')
        return image, text, (original_size, target_size, crops_coords_left_top)


batch_size = 8    # batch size per GPU
world_size = 8    # total number of GPUs
rank = 0          # rank of the current process
num_workers = 4   # customizable based on actual conditions
shuffle = False   # must be set to False
drop_last = True  # must be set to True

# The correct batch_size and world_size must be passed in, otherwise the multi-resolution data
# cannot be aligned correctly.
dataset = SimpleDataset('data_multireso.json', batch_size=batch_size, world_size=world_size)
# BlockDistributedSampler must be used so that the samples in a batch have the same resolution.
sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                  shuffle=shuffle, drop_last=drop_last, batch_size=batch_size)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                    num_workers=num_workers, pin_memory=True, drop_last=drop_last)

for epoch in range(10):
    # Use the shuffle method provided by the dataset, not the DataLoader's shuffle parameter.
    dataset.shuffle(epoch, fast=True)
    for batch in loader:
        pass
```
## Fast Shuffle
When your index v2 file contains hundreds of millions of samples, using the default shuffle can be quite slow. Therefore, it’s recommended to enable `fast=True` mode:
```python
index_manager.shuffle(seed=1234, fast=True)
```
This shuffles globally without keeping the indices of the same arrow together. Although this may reduce reading speed, whether that matters is a trade-off against the model's forward time.
#!/usr/bin/env python3
import argparse
from index_kits.dataset.make_dataset_core import startup, make_multireso
from index_kits.common import show_index_info
from index_kits import __version__
def common_args(parser):
parser.add_argument('-t', '--target', type=str, required=True, help='Save path')
def get_args():
parser = argparse.ArgumentParser(description="""
IndexKits is a tool to build and manage index files for large-scale datasets.
It supports both base index and multi-resolution index.
Introduction
------------
This command line tool provides the following functionalities:
1. Show index v2 information
2. Build base index v2
3. Build multi-resolution index v2
Examples
--------
1. Show index v2 information
index_kits show /path/to/index.json
2. Build base index v2
Default usage:
index_kits base -c /path/to/config.yaml -t /path/to/index.json
Use multiple processes:
index_kits base -c /path/to/config.yaml -t /path/to/index.json -w 40
3. Build multi-resolution index v2
Build with a configuration file:
index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json
Build by specifying arguments without a configuration file:
index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
Build by specifying target-ratios:
index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json
Build with multiple source index files.
index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
""", formatter_class=argparse.RawTextHelpFormatter)
sub_parsers = parser.add_subparsers(dest='task', required=True)
# Show index message
show_parser = sub_parsers.add_parser('show', description="""
Show base/multireso index v2 information.
Example
-------
index_kits show /path/to/index.json
""", formatter_class=argparse.RawTextHelpFormatter)
show_parser.add_argument('src', type=str, help='Path to a base/multireso index file.')
show_parser.add_argument('--arrow-files', action='store_true', help='Show arrow files only.')
show_parser.add_argument('--depth', type=int, default=1,
help='Arrow file depth. Default is 1, the level of last folder in the arrow file path. '
'Set it to 0 to show the full path including `xxx/last_folder/*.arrow`.')
# Single resolution bucket
base_parser = sub_parsers.add_parser('base', description="""
Build base index v2.
Example
-------
index_kits base -c /path/to/config.yaml -t /path/to/index.json
""", formatter_class=argparse.RawTextHelpFormatter)
base_parser.add_argument('-c', '--config', type=str, required=True, help='Configuration file path')
common_args(base_parser)
base_parser.add_argument('-w', '--world-size', type=int, default=1)
base_parser.add_argument('--work-dir', type=str, default='.', help='Work directory')
base_parser.add_argument('--use-cache', action='store_true', help='Use cache to avoid reprocessing. '
'Merge the cached pkl results directly.')
# Multi-resolution bucket
mo_parser = sub_parsers.add_parser('multireso', description="""
Build multi-resolution index v2
Example
-------
Build with a configuration file:
index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json
Build by specifying arguments without a configuration file:
index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
Build by specifying target-ratios:
index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json
Build with multiple source index files.
index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
""", formatter_class=argparse.RawTextHelpFormatter)
mo_parser.add_argument('-c', '--config', type=str, default=None,
help='Configuration file path in a yaml format. Either --config or --src must be provided.')
mo_parser.add_argument('-s', '--src', type=str, nargs='+', default=None,
help='Source index files. Either --config or --src must be provided.')
common_args(mo_parser)
mo_parser.add_argument('--base-size', type=int, default=None, help="Base size. Typically set as 256/512/1024 according to image size you train model.")
mo_parser.add_argument('--reso-step', type=int, default=None,
help="Resolution step. Either reso_step or target_ratios must be provided.")
mo_parser.add_argument('--target-ratios', type=str, nargs='+', default=None,
help="Target ratios. Either reso_step or target_ratios must be provided.")
mo_parser.add_argument('--md5-file', type=str, default=None,
help='You can provide an md5 to height and width file to accelerate the process. '
'It is a pickle file that contains a dict, which maps md5 to (height, width) tuple.')
mo_parser.add_argument('--align', type=int, default=16, help="Used when --target-ratios is provided. Align size of source image height and width.")
mo_parser.add_argument('--min-size', type=int, default=0,
help="Minimum size. Images smaller than this size will be ignored.")
# Common
parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = get_args()
if args.task == 'show':
show_index_info(args.src,
args.arrow_files,
args.depth,
)
elif args.task == 'base':
startup(args.config,
args.target,
args.world_size,
args.work_dir,
use_cache=args.use_cache,
)
elif args.task == 'multireso':
make_multireso(args.target,
args.config,
args.src,
args.base_size,
args.reso_step,
args.target_ratios,
args.align,
args.min_size,
args.md5_file,
)
# Use the IndexKits library to create datasets.
[TOC]
## Introduction
The IndexKits library offers a command-line tool, `idk`, for creating datasets and viewing their statistics. You can view the usage instructions with `idk -h`. Creating a dataset here means building an Index V2 format dataset from a series of arrow files.
## 1. Creating a Base Index V2 Dataset
### 1.1 Create a Base Index V2 dataset using `idk`
When creating a Base Index V2 dataset, you need to specify the path to a configuration file using the `-c` parameter and a save path using the `-t` parameter.
```shell
idk base -c base_config.yaml -t base_dataset.json
```
### 1.2 Basic Configuration
Next, let’s discuss how to write the configuration file. The configuration file is in `yaml` format, and below is a basic example:
Filename: `base_config.yaml`
```yaml
source:
- /HunYuanDiT/dataset/porcelain/arrows/00000.arrow
```
| Field Name | Type | Description |
|:---------:|:---:|:----------:|
| source | Optional | Arrow List |
We provide an example that includes all features and fields in [full_config.yaml](./docs/full_config.yaml).
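The config is plain YAML, so it can be sanity-checked before running `idk`; a minimal sketch (not part of `idk` itself), assuming PyYAML is installed:
```python
import yaml

# Load the config used above and check that at least one arrow source is listed.
with open('base_config.yaml') as f:
    cfg = yaml.safe_load(f)
assert isinstance(cfg.get('source'), list) and cfg['source'], 'source must be a non-empty list'
print(cfg['source'])
```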
### 1.3 Filtering Criteria
`idk` offers two types of filtering capabilities during the dataset creation process: (1) based on columns in Arrow, and (2) based on MD5 files.
To enable filtering criteria, add the `filter` field in the configuration file.
#### 1.3.1 Column Filtering
To enable column filtering, add the `column` field under the `filter` section.
Multiple column filtering criteria can be applied simultaneously, with the intersection of multiple conditions being taken.
For example, to select data where both the height and width are greater than or equal to 512,
using 1024 as the default when the height or width is invalid:
```yaml
filter:
column:
- name: height
type: int
action: ge
target: 512
default: 1024
- name: width
type: int
action: ge
target: 512
default: 1024
```
This filtering condition is equivalent to `table['height'].to_int(default=1024) >= 512 && table['width'].to_int(default=1024) >= 512`.
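As an illustration only (this is not the actual `idk` implementation), the same condition could be expressed directly against a pyarrow table, assuming `height` and `width` are numeric columns:
```python
import pyarrow as pa
import pyarrow.compute as pc


def apply_min_size_filter(table: pa.Table, target: int = 512, default: int = 1024) -> pa.Table:
    # Invalid/missing values fall back to the configured default (1024), then both
    # conditions must hold: height >= 512 AND width >= 512.
    height = pc.fill_null(table['height'], default)
    width = pc.fill_null(table['width'], default)
    mask = pc.and_(pc.greater_equal(height, target), pc.greater_equal(width, target))
    return table.filter(mask)
```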
Each filtering criterion includes the following fields:
| Field Name | Type | Description | Value Range |
|:------------------:|:---:|:---------------:|:----------------------:|
| name | Required | Column name in Arrow | Column in Arrow |
| type | Required | Type of Elements in the Column | `int`, `float` or `str` |
| action | Required | Filtering Criteria | See the table below for possible values |
| target | Required | Filtering Criteria | Numeric or String |
| default | Required | Default value when the element is invalid | Numeric or String |
| arrow_file_keyword | Optional | Keywords in the Arrow file path | - |
Below are the specific meanings of `action` and the values allowed in different circumstances:
| Action | Description | Type | Action | Description | Type |
|:-------------:|:----------------------------------------------------:|:---------------------:|:------------:|:--------------------------------------------:|:---------------------:|
| eq | equal, `==` | `int`, `float`, `str` | ne | not equal, `!=` | `int`, `float`, `str` |
| gt | greater than, `>` | `int`, `float` | lt | less than, `<` | `int`, `float` |
| ge | greater than or equal, `>=` | `int`, `float` | le | less than or equal, `<=` | `int`, `float` |
| len_eq | str length equal, `str.len()==` | `str` | len_ne | str length not equal, `str.len()!=` | `str` |
| len_gt | str length greater than, `str.len()>` | `str` | len_lt | str length less than, `str.len()<` | `str` |
| len_ge | str length greater than or equal, `str.len()>=` | `str` | len_le | str length less than or equal, `str.len()<=` | `str` |
| contains | str contains, `str.contains(target)` | `str` | not_contains | str does not contain, `str.not_contains(target)` | `str` |
| in | str in, `str.in(target)` | `str` | not_in | str not in, `str.not_in(target)` | `str` |
| lower_last_in | lower str last char in, `str.lower()[-1].in(target)` | `str` | | | |
#### 1.3.2 MD5 Filtering
Add an `md5` field under the `filter` section to enable MD5 filtering. Multiple MD5 filtering criteria can be applied simultaneously, with the intersection of multiple conditions being taken.
For example:
* `badcase.txt` is a list of MD5s, aiming to filter out the entries listed in these lists.
* `badcase.json` is a dictionary, where the key is the MD5 and the value is a text-related `tag`. The goal is to filter out specific `tags`.
```yaml
filter:
md5:
- name: badcase1
path:
- badcase1.txt
type: list
action: in
is_valid: false
- name: badcase2
path: badcase2.json
type: dict
action: eq
target: 'Specified tag'
is_valid: false
```
Each filtering criterion includes the following fields:
| Field Name | Type | Description | Value Range |
|:------------------:|:---:|:-----------------------------------------------------------------:|:------------------------------------------------------------------------------------:|
| name | Required | The name of the filtering criterion, which can be customized for ease of statistics. | - |
| path | Required | The path to the filtering file, which can be a single path or multiple paths provided in a list format. Supports `.txt`, `.json`, `.pkl` formats. | - |
| type | Required | The type of records in the filtering file. | `list` or `dict` |
| action | Required | The filtering action. | For `list`: `in`, `not_in`; For `dict`: `eq`, `ne`, `gt`, `lt`, `ge`, `le`|
| target | Optional | The filtering criterion. | Required when type is `dict` |
| is_valid | Required | Whether a hit on action+target is considered valid or invalid. | `true` or `false` |
| arrow_file_keyword | Optional | Keywords in the Arrow file path. | - |
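For intuition only (not the actual `idk` code), the two example criteria above amount to the following per-sample check:
```python
import json

# badcase1.txt: one MD5 per line; action=in with is_valid=false means listed samples are dropped.
with open('badcase1.txt') as f:
    bad_md5s = {line.strip() for line in f if line.strip()}

# badcase2.json: {md5: tag}; action=eq with target='Specified tag' and is_valid=false drops matches.
with open('badcase2.json') as f:
    md5_to_tag = json.load(f)


def keep_sample(md5):
    if md5 in bad_md5s:
        return False
    if md5_to_tag.get(md5) == 'Specified tag':
        return False
    return True
```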
### 1.4 Advanced Filtering
`idk` also supports some more advanced filtering functions.
#### 1.4.1 Filtering Criteria Applied to Part of Arrow Files
Using the `arrow_file_keyword` parameter, filtering criteria can be applied only to part of the Arrow files.
For example:
* The filtering criterion `height >= 512` applies only to arrows whose paths include `human`.
* The criterion "keep the samples listed in goodcase.txt" applies only to arrows whose paths include `human`.
```yaml
filter:
column:
- name: height
type: int
action: ge
target: 512
default: 1024
arrow_file_keyword:
- human
md5:
- name: goodcase
path: goodcase.txt
type: list
action: in
is_valid: true
arrow_file_keyword:
- human
```
#### 1.4.2 The “or” Logic in Filtering Criteria
By default, filtering criteria follow an “and” logic. If you want two or more filtering criteria to follow an “or” logic, you should use the `logical_or` field. The column filtering criteria listed under this field will be combined using an “or” logic.
```yaml
filter:
column:
- logical_or:
- name: md5
type: str
action: lower_last_in
target: '02468ace'
default: ''
- name: text_zh
type: str
action: contains
target: 'turtle|rabbit|duck|swan|peacock|hedgehog|wolf|fox|seal|crow|deer'
default: ''
```
Special Note: The `logical_or` field is applicable only to the `column` filtering criteria within `filter`.
#### 1.4.3 Excluding Certain Arrows from the Source
While wildcards can be used to fetch multiple arrows at once, there might be instances where we want to exclude some of them. This can be achieved through the `exclude` field. Keywords listed under `exclude`, if found in the path of the current group of arrows, will result in those arrows being excluded.
```yaml
source:
- /HunYuanDiT/dataset/porcelain/arrows/*.arrow:
exclude:
- arrow1
- arrow2
```
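Conceptually (a rough sketch, not the literal `idk` implementation), `exclude` behaves like a keyword filter applied after the wildcard is expanded:
```python
from glob import glob

# Expand the wildcard, then drop any arrow whose path contains an excluded keyword.
matched = glob('/HunYuanDiT/dataset/porcelain/arrows/*.arrow')
exclude = ['arrow1', 'arrow2']
arrow_files = [p for p in matched if not any(keyword in p for keyword in exclude)]
```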
### 1.5 Repeating Samples
`idk` offers the capability to repeat either all samples or specific samples during dataset creation. There are three types of repeaters:
* Directly repeating the source
* Based on keywords in the Arrow file name (enable repeat conditions by adding the `repeater` field in the configuration file)
* Based on an MD5 file (enable repeat conditions by adding the `repeater` field in the configuration file)
**Special Note:** The above three conditions can be used simultaneously. If a sample meets multiple repeat conditions, the highest number of repeats will be taken.
#### 1.5.1 Repeating the Source
In the source, for the Arrow(s) you want to repeat (which can include wildcards like `*`, `?`), add `repeat: n` to mark the number of times to repeat.
```yaml
source:
- /HunYuanDiT/dataset/porcelain/arrows/*.arrow:
repeat: 10
```
**Special Note:** Add a colon at the end of the Arrow path.
#### 1.5.2 Arrow Keywords
Add an `arrow_file_keyword` field under the `repeater` section.
```yaml
repeater:
arrow_file_keyword:
- repeat: 8
keyword:
- Lolita anime style
- Minimalist style
- repeat: 5
keyword:
- Magical Barbie style
- Disney style
```
Each repeat condition includes two fields:
| Field Name | Type | Description | Value Range |
|:-------:|:---:|:---------------:|:----:|
| repeat | Required | The number of times to repeat | Number |
| keyword | Required | Keywords in the Arrow file path | - |
#### 1.5.3 MD5 File
Add an `md5` field under the `repeater` section.
```yaml
repeater:
md5:
- name: goodcase1
path: /HunYuanDiT/dataset/porcelain/md5_repeat_1.json
type: dict
plus: 3
- name: goodcase2
path: /HunYuanDiT/dataset/porcelain/md5_repeat_2.json
type: list
repeat: 6
```
Each repeat condition includes the following fields:
| Field Name | Type | Description | Value Range |
|:------:|:---:|:----------------------------------------------:|:---------------------------:|
| name | Required | Custom name for the repeat condition | - |
| path | Required | Path to the file containing MD5s | `.txt`, `.json`, or `.pkl` formats |
| type | Required | Type of the MD5 file | `list` or `dict` |
| repeat | Optional | Number of times to repeat; overrides `plus` | Integer |
| plus | Optional | Added to the value obtained from the MD5 file to get the number of repeats | Integer |
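As a rough sketch of the intended semantics (the helper below is hypothetical, not part of `index_kits`), the repeat count for one MD5-based condition could be resolved like this; when several conditions match a sample, the highest count wins (see the note in 1.5):
```python
def effective_repeat(md5, md5_records, repeat=None, plus=None):
    """Hypothetical helper illustrating the `repeat` / `plus` fields above."""
    if md5 not in md5_records:        # works for both a list of MD5s and a {md5: value} dict
        return 1                      # not listed: keep the sample once
    if repeat is not None:
        return repeat                 # `repeat` overrides `plus`
    if plus is not None and isinstance(md5_records, dict):
        return int(md5_records[md5]) + plus   # add `plus` to the value stored for this MD5
    return 1
```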
### 1.6 Deduplication
To deduplicate, add `remove_md5_dup` and set it to `true`. For example:
```yaml
remove_md5_dup: true
```
**Special Note:** Deduplication is performed after repeat conditions, so using them together will make repeat ineffective.
### 1.7 Creating a Base Index V2 Dataset with Python Code
```python
from index_kits import IndexV2Builder
builder = IndexV2Builder(arrow_files)
builder.save('data_v2.json')
```
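The saved file can then be loaded like any other Index V2 file (see "Loading an Index V2 File" in the README):
```python
from index_kits import ArrowIndexV2

index_manager = ArrowIndexV2('data_v2.json')
print(len(index_manager))
```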
## 2. Creating a Multireso Index V2 Dataset
### 2.1 Creating a Multireso Index V2 Dataset with `idk`
A multi-resolution dataset can be created through a configuration file:
```yaml
src:
- /HunYuanDiT/dataset/porcelain/jsons/a.json
- /HunYuanDiT/dataset/porcelain/jsons/b.json
- /HunYuanDiT/dataset/porcelain/jsons/c.json
base_size: 512
reso_step: 32
min_size: 512
```
The fields are as follows:
| Field Name | Type | Description | Value Range |
|:-------------:|:---:|:------------------------------------------------------:|:-------------------------------:|
| src | Required | Path(s) to the Base Index V2 file, can be single or multiple | - |
| base_size | Required | The base resolution (n, n) from which to create multiple resolutions | Recommended values: 256/512/1024 |
| reso_step | Optional | The step size for traversing multiple resolutions. Choose either this or target_ratios | Recommended values: 16/32/64 |
| target_ratios | Optional | Target aspect ratios, a list. Choose either this or reso_step | Recommended values: 1:1, 4:3, 3:4, 16:9, 9:16 |
| align | Optional | When using target_ratios, the multiple to align the target resolution to | Recommended value: 16 (2x patchify, 8x VAE) |
| min_size | Optional | The minimum resolution filter for samples when creating from Base to Multireso Index V2 | Recommended values: 256/512/1024 |
| md5_file | Optional | A pre-calculated dictionary of image sizes in pkl format, key is MD5, value is (h, w) | - |
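With the configuration above saved to a file (the filename below is illustrative), the multi-resolution index is built the same way as the base index:
```shell
idk multireso -c multireso_config.yaml -t data_multireso.json
```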
### 2.2 Creating a Multireso Index V2 Dataset with Python Code
First, import the necessary function
```python
from index_kits import build_multi_resolution_bucket
md5_hw = None
# If many arrows in your Index V2 lack height and width, you can pre-calculate them and pass through the md5_hw parameter.
# md5_hw = {'c67be1d8f30fd0edcff6ac99b703879f': (720, 1280), ...}
# with open('md5_hw.pkl', 'rb') as f:
# md5_hw = pickle.load(f)
index_v2_file_path = 'data_v2.json'
index_v2_save_path = 'data_multireso.json'
# Method 1: Given base_size and reso_step, automatically calculate all resolution buckets.
build_multi_resolution_bucket(base_size=1024,
reso_step=64,
min_size=1024,
src_index_files=index_v2_file_path,
save_file=index_v2_save_path,
md5_hw=md5_hw)
# Method 2: Given a series of target aspect ratios, automatically calculate all resolution buckets.
build_multi_resolution_bucket(base_size=1024,
target_ratios=["1:1", "3:4", "4:3", "16:9", "9:16"],
align=16,
min_size=1024,
src_index_files=index_v2_file_path,
save_file=index_v2_save_path,
md5_hw=md5_hw)
```
**Note:** If both reso_step and target_ratios are provided, target_ratios will be prioritized.
source:
- /HunyuanDiT/dataset/task1/*.arrow
- /HunyuanDiT/dataset/task2/*.arrow:
repeat: 2
remove_md5_dup: false
filter:
column:
- name: height
type: int
action: ge
target: 512
default: 1024
- name: width
type: int
action: ge
target: 512
default: 1024
md5:
- name: Badcase
path: /path/to/bad_case_md5.txt
type: list
action: in
is_valid: false
repeater:
md5:
- name: Hardcase Repeat
path: /path/to/hard_case_md5.json
type: dict
plus: 3
- name: Goodcase Repeat
path: /path/to/good_case_md5.txt
type: list
repeat: 6
from .bucket import (
MultiIndexV2,
MultiResolutionBucketIndexV2,
MultiMultiResolutionBucketIndexV2,
build_multi_resolution_bucket,
)
from .bucket import Resolution, ResolutionGroup
from .indexer import IndexV2Builder, ArrowIndexV2
from .common import load_index, show_index_info
__version__ = "0.3.5"
import bisect
import json
import random
from pathlib import Path
from typing import Optional, Dict, List, Callable, Union
import numpy as np
from tqdm import tqdm
from PIL import Image
from .indexer import ArrowIndexV2, IndexV2Builder
class Resolution(object):
def __init__(self, size, *args):
if isinstance(size, str):
if "x" in size:
size = size.split("x")
size = (int(size[0]), int(size[1]))
else:
size = int(size)
if len(args) > 0:
size = (size, args[0])
if isinstance(size, int):
size = (size, size)
self.h = self.height = size[0]
self.w = self.width = size[1]
self.r = self.ratio = self.height / self.width
def __getitem__(self, idx):
if idx == 0:
return self.h
elif idx == 1:
return self.w
else:
raise IndexError(f"Index {idx} out of range")
def __str__(self):
return f"{self.h}x{self.w}"
class ResolutionGroup(object):
def __init__(
self,
base_size=None,
step=None,
align=1,
target_ratios=None,
enlarge=1,
data=None,
):
self.enlarge = enlarge
if data is not None:
self.data = data
mid = len(self.data) // 2
self.base_size = self.data[mid].h
self.step = self.data[mid].h - self.data[mid - 1].h
else:
self.align = align
self.base_size = base_size
assert (
base_size % align == 0
), f"base_size {base_size} is not divisible by align {align}"
if base_size is not None and not isinstance(base_size, int):
raise ValueError(
f"base_size must be None or int, but got {type(base_size)}"
)
if step is None and target_ratios is None:
raise ValueError(f"Either step or target_ratios must be provided")
if step is not None and step > base_size // 2:
raise ValueError(
f"step must be smaller than base_size // 2, but got {step} > {base_size // 2}"
)
self.step = step
self.data = self.calc(target_ratios)
self.ratio = np.array([x.ratio for x in self.data])
self.attr = ["" for _ in range(len(self.data))]
self.prefix_space = 0
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def __repr__(self):
prefix = self.prefix_space * " "
prefix_close = (self.prefix_space - 4) * " "
res_str = f"ResolutionGroup(base_size={self.base_size}, step={self.step}, data="
attr_maxlen = max([len(x) for x in self.attr] + [5])
res_str += f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}'
res_str += ("\n" + prefix).join(
[
f"{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} "
f"({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}"
for i, x in enumerate(self.data)
]
)
res_str += f"\n{prefix_close})"
return res_str
@staticmethod
def from_list_of_hxw(hxw_list):
data = [Resolution(x) for x in hxw_list]
data = sorted(data, key=lambda x: x.ratio)
return ResolutionGroup(None, data=data)
def calc(self, target_ratios=None):
if target_ratios is None:
return self._calc_by_step()
else:
return self._calc_by_ratio(target_ratios)
def _calc_by_ratio(self, target_ratios):
resolutions = []
for ratio in target_ratios:
if ratio == "1:1":
reso = Resolution(self.base_size, self.base_size)
else:
hr, wr = map(int, ratio.split(":"))
x = int(
(
self.base_size**2
* self.enlarge
// self.align
// self.align
/ (hr * wr)
)
** 0.5
)
height = x * hr * self.align
width = x * wr * self.align
reso = Resolution(height, width)
resolutions.append(reso)
resolutions = sorted(resolutions, key=lambda x_: x_.ratio)
return resolutions
def _calc_by_step(self):
min_height = self.base_size // 2
min_width = self.base_size // 2
max_height = self.base_size * 2
max_width = self.base_size * 2
resolutions = [Resolution(self.base_size, self.base_size)]
cur_height, cur_width = self.base_size, self.base_size
while True:
if cur_height >= max_height and cur_width <= min_width:
break
cur_height = min(cur_height + self.step, max_height)
cur_width = max(cur_width - self.step, min_width)
resolutions.append(Resolution(cur_height, cur_width))
cur_height, cur_width = self.base_size, self.base_size
while True:
if cur_height <= min_height and cur_width >= max_width:
break
cur_height = max(cur_height - self.step, min_height)
cur_width = min(cur_width + self.step, max_width)
resolutions.append(Resolution(cur_height, cur_width))
resolutions = sorted(resolutions, key=lambda x: x.ratio)
return resolutions
class Bucket(ArrowIndexV2):
def __init__(self, height, width, *args, **kwargs):
super().__init__(*args, **kwargs)
self.height = height
self.width = width
self.ratio = height / width
self.scale_dist = []
def get_scale_by_index(self, index, size_col="hw", shadow=None):
"""
Calculate the scale to resize the image to fit the bucket.
Parameters
----------
index: int
An in-json index.
size_col: str
How to get the size of the image. 'hw' for height and width column,
while 'image' for decoding image binary and get the PIL Image size.
shadow: str
The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file.
Returns
-------
scale: float
"""
if size_col == "hw":
w = int(self.get_attribute_by_index(index, "width", shadow=shadow))
h = int(self.get_attribute_by_index(index, "height", shadow=shadow))
else:
w, h = self.get_image_by_index(index, shadow=shadow).size
tw, th = self.width, self.height
tr = th / tw
r = h / w
scale = th / h if r < tr else tw / w
return scale
@staticmethod
def from_bucket_index(index_file, align=1, shadow_file_fn=None):
with open(index_file, "r") as f:
res_dict = json.load(f)
if not isinstance(res_dict["group_length"], dict):
error_msg = f'`group_length` must be a dict, but got {type(res_dict["group_length"])}'
if isinstance(res_dict["group_length"], list):
raise ValueError(
f"{error_msg}\nYou may be using a vanilla Index V2 file. Try `ArrowIndexV2` instead."
)
else:
raise ValueError(error_msg)
assert "indices_file" in res_dict, f"indices_file not found in {index_file}"
assert res_dict["indices_file"] != "", f"indices_file is empty in {index_file}"
indices_file = Path(index_file).parent / res_dict["indices_file"]
assert Path(indices_file).exists(), f"indices_file {indices_file} not found"
# Loading indices data
indices_data = np.load(indices_file)
# Build buckets
buckets = []
keys = []
for k, v in indices_data.items():
data = {
"data_type": res_dict["data_type"],
"arrow_files": res_dict["arrow_files"],
"cum_length": res_dict["cum_length"],
}
data["indices_file"] = ""
data["indices"] = v
data["group_length"] = res_dict["group_length"][k]
height, width = map(int, k.split("x"))
bucket = Bucket(
height, width, res_dict=data, align=align, shadow_file_fn=shadow_file_fn
)
if len(bucket) > 0:
buckets.append(bucket)
keys.append(k)
resolutions = ResolutionGroup.from_list_of_hxw(keys)
resolutions.attr = [f"{len(bucket):,d}" for bucket in buckets]
return buckets, resolutions
class MultiIndexV2(object):
"""
Multi-bucket index. Support multi-GPU (either single node or multi-node distributed) training.
Parameters
----------
index_files: list
The index files.
batch_size: int
The batch size of each GPU. Required when using MultiResolutionBucketIndexV2 as base index class.
world_size: int
The number of GPUs. Required when using MultiResolutionBucketIndexV2 as base index class.
sample_strategy: str
The sample strategy. Can be 'uniform' or 'probability'. Default to 'uniform'.
If set to probability, a list of probability must be provided. The length of the list must be the same
as the number of buckets. Each probability value means the sample rate of the corresponding bucket.
probability: list
A list of probability. Only used when sample_strategy=='probability'.
shadow_file_fn: callable or dict
A callable function to map shadow file path to a new path. If None, the shadow file path will not be
changed. If a dict is provided, the keys are the shadow names to call the function, and the values are the
callable functions to map the shadow file path to a new path. If a callable function is provided, the key
is 'default'.
seed: int
Only used when sample_strategy=='probability'. The seed to sample the indices.
"""
buckets: List[ArrowIndexV2]
def __init__(
self,
index_files: List[str],
batch_size: Optional[int] = None,
world_size: Optional[int] = None,
sample_strategy: str = "uniform",
probability: Optional[List[float]] = None,
shadow_file_fn: Optional[Union[Callable, Dict[str, Callable]]] = None,
seed: Optional[int] = None,
):
self.buckets = self.load_buckets(
index_files,
batch_size=batch_size,
world_size=world_size,
shadow_file_fn=shadow_file_fn,
)
self.sample_strategy = sample_strategy
self.probability = probability
self.check_sample_strategy(sample_strategy, probability)
self.cum_length = self.calc_cum_length()
self.sampler = np.random.RandomState(seed)
if sample_strategy == "uniform":
self.total_length = sum([len(bucket) for bucket in self.buckets])
self.ind_mapper = np.arange(self.total_length)
elif sample_strategy == "probability":
self.ind_mapper = self.sample_indices_with_probability()
self.total_length = len(self.ind_mapper)
else:
raise ValueError(f"Not supported sample_strategy {sample_strategy}.")
def load_buckets(self, index_files, **kwargs):
buckets = [ArrowIndexV2(index_file, **kwargs) for index_file in index_files]
return buckets
def __len__(self):
return self.total_length
def check_sample_strategy(self, sample_strategy, probability):
if sample_strategy == "uniform":
pass
elif sample_strategy == "probability":
if probability is None:
raise ValueError(
f"probability must be provided when sample_strategy is 'probability'."
)
assert isinstance(
probability, (list, tuple)
), f"probability must be a list, but got {type(probability)}"
assert len(self.buckets) == len(
probability
), f"Length of index_files {len(self.buckets)} != Length of probability {len(probability)}"
else:
raise ValueError(f"Not supported sample_strategy {sample_strategy}.")
def sample_indices_with_probability(self):
ind_mapper_list = []
accu = 0
for bucket, p in zip(self.buckets, self.probability):
if p == 1:
# Just use all indices
indices = np.arange(len(bucket)) + accu
else:
# Use all indices multiple times, and then sample some indices without replacement
repeat_times = int(p)
indices_part1 = np.arange(len(bucket)).repeat(repeat_times)
indices_part2 = self.sampler.choice(
len(bucket), int(len(bucket) * (p - repeat_times)), replace=False
)
indices = np.sort(np.concatenate([indices_part1, indices_part2])) + accu
ind_mapper_list.append(indices)
accu += len(bucket)
ind_mapper = np.concatenate(ind_mapper_list)
return ind_mapper
def calc_cum_length(self):
cum_length = []
length = 0
for bucket in self.buckets:
length += len(bucket)
cum_length.append(length)
return cum_length
def shuffle(self, seed=None, fast=False):
if self.sample_strategy == "probability":
# Notice: In order to resample indices when shuffling, shuffle will not preserve the
# initial sampled indices when loading the index.
pass
# Shuffle indexes
if seed is not None:
state = random.getstate()
random.seed(seed)
random.shuffle(self.buckets)
random.setstate(state)
else:
random.shuffle(self.buckets)
self.cum_length = self.calc_cum_length()
# Shuffle indices in each index
for i, bucket in enumerate(self.buckets):
bucket.shuffle(seed + i, fast=fast)
# Shuffle ind_mapper
if self.sample_strategy == "uniform":
self.ind_mapper = np.arange(self.total_length)
elif self.sample_strategy == "probability":
self.ind_mapper = self.sample_indices_with_probability()
else:
raise ValueError(f"Not supported sample_strategy {self.sample_strategy}.")
if seed is not None:
sampler = np.random.RandomState(seed)
sampler.shuffle(self.ind_mapper)
else:
np.random.shuffle(self.ind_mapper)
def get_arrow_file(self, ind, **kwargs):
"""
Get arrow file by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
kwargs: dict
shadow: str
The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file.
Returns
-------
arrow_file: str
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_arrow_file(ind - bias, **kwargs)
def get_data(
self, ind, columns=None, allow_missing=False, return_meta=True, **kwargs
):
"""
Get data by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
columns: str or list
The columns to be returned. If None, return all columns.
allow_missing: bool
If True, omit missing columns. If False, raise an error if the column is missing.
return_meta: bool
If True, the resulting dict will contain some meta information:
in-json index, in-arrow index, and arrow_name.
kwargs: dict
shadow: str
The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file.
Returns
-------
data: dict
A dict containing the data.
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_data(
ind - bias,
columns=columns,
allow_missing=allow_missing,
return_meta=return_meta,
**kwargs,
)
def get_attribute(self, ind, column, **kwargs):
"""
Get single attribute by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
column: str
The column name.
kwargs: dict
shadow: str
The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file.
Returns
-------
attribute: Any
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_attribute(ind - bias, column, **kwargs)
def get_image(self, ind, column="image", ret_type="pil", max_size=-1, **kwargs):
"""
Get image by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
column: str
[Deprecated] The column name of the image. Default to 'image'.
ret_type: str
The return type. Can be 'pil' or 'numpy'. Default to 'pil'.
max_size: int
If not -1, resize the image to max_size. max_size is the size of long edge.
kwargs: dict
shadow: str
The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file.
Returns
-------
image: PIL.Image.Image or np.ndarray
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_image(
ind - bias, column, ret_type, max_size, **kwargs
)
def get_md5(self, ind, **kwargs):
"""Get md5 by in-dataset index."""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_md5(ind - bias, **kwargs)
def get_columns(self, ind, **kwargs):
"""Get columns by in-dataset index."""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_columns(ind - bias, **kwargs)
@staticmethod
def resize_and_crop(image, target_size, resample=Image.LANCZOS, crop_type="random"):
"""
Resize image without changing aspect ratio, then crop the center/random part.
Parameters
----------
image: PIL.Image.Image
The input image to be resized and cropped.
target_size: tuple
The target size of the image.
resample:
The resample method. See PIL.Image.Image.resize for details. Default to Image.LANCZOS.
crop_type: str
'center' or 'random'. If 'center', crop the center part of the image. If 'random',
crop a random part of the image. Default to 'random'.
Returns
-------
image: PIL.Image.Image
The resized and cropped image.
crop_pos: tuple
The position of the cropped part. (crop_left, crop_top)
"""
return ArrowIndexV2.resize_and_crop(image, target_size, resample, crop_type)
class MultiResolutionBucketIndexV2(MultiIndexV2):
"""
Multi-resolution bucket index. Support multi-GPU (either single node or multi-node distributed) training.
Parameters
----------
index_file: str
The index file of the bucket index.
batch_size: int
The batch size of each GPU.
world_size: int
The number of GPUs.
shadow_file_fn: callable or dict
A callable function to map shadow file path to a new path. If None, the shadow file path will not be
changed. If a dict is provided, the keys are the shadow names to call the function, and the values are the
callable functions to map the shadow file path to a new path. If a callable function is provided, the key
is 'default'.
"""
buckets: List[Bucket]
def __init__(
self,
index_file: str,
batch_size: int,
world_size: int,
shadow_file_fn: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
align = batch_size * world_size
if align <= 0:
raise ValueError(
f"Align size must be positive, but got {align} = {batch_size} x {world_size}"
)
self.buckets, self._resolutions = Bucket.from_bucket_index(
index_file,
align=align,
shadow_file_fn=shadow_file_fn,
)
self.arrow_files = self.buckets[0].arrow_files
self._base_size = self._resolutions.base_size
self._step = self._resolutions.step
self.buckets = sorted(self.buckets, key=lambda x: x.ratio)
self.cum_length = self.calc_cum_length()
self.total_length = sum([len(bucket) for bucket in self.buckets])
assert (
self.total_length % align == 0
), f"Total length {self.total_length} is not divisible by align size {align}"
self.align_size = align
self.batch_size = batch_size
self.world_size = world_size
self.ind_mapper = np.arange(self.total_length)
@property
def step(self):
return self._step
@property
def base_size(self):
return self._base_size
@property
def resolutions(self):
return self._resolutions
def shuffle(self, seed=None, fast=False):
# Shuffle indexes
if seed is not None:
state = random.getstate()
random.seed(seed)
random.shuffle(self.buckets)
random.setstate(state)
else:
random.shuffle(self.buckets)
self.cum_length = self.calc_cum_length()
# Shuffle indices in each index
for i, bucket in enumerate(self.buckets):
bucket.shuffle(seed + i, fast=fast)
# Shuffle ind_mapper
batch_ind_mapper = (
np.arange(self.total_length // self.batch_size) * self.batch_size
)
if seed is not None:
sampler = np.random.RandomState(seed)
sampler.shuffle(batch_ind_mapper)
else:
np.random.shuffle(batch_ind_mapper)
ind_mapper = np.stack(
[batch_ind_mapper + i for i in range(self.batch_size)], axis=1
).reshape(-1)
self.ind_mapper = ind_mapper
def get_ratio(self, ind, **kwargs):
"""
Get the ratio of the image by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
Returns
-------
width, height, ratio
"""
ind = self.ind_mapper[ind]
width, height = self.get_image(ind, **kwargs).size
return width, height, height / width
def get_target_size(self, ind):
"""
Get the target size of the image by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
Returns
-------
target_width, target_height
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
return self.buckets[i].width, self.buckets[i].height
def scale_distribution(self, save_file=None):
if save_file is not None:
scale_dict = np.load(save_file)
for bucket in self.buckets:
bucket.scale_dist = scale_dict[f"{bucket.height}x{bucket.width}"]
else:
for bucket in tqdm(self.buckets):
for index in tqdm(bucket.indices, leave=False):
scale = bucket.get_scale_by_index(index)
bucket.scale_dist.append(scale)
scale_dict = {
f"{bucket.height}x{bucket.width}": bucket.scale_dist
for bucket in self.buckets
}
if save_file is not None:
save_file = Path(save_file)
save_file.parent.mkdir(exist_ok=True, parents=True)
np.savez_compressed(save_file, **scale_dict)
return self
class MultiMultiResolutionBucketIndexV2(MultiIndexV2):
buckets: List[MultiResolutionBucketIndexV2]
@property
def step(self):
return [b.step for b in self.buckets]
@property
def base_size(self):
return [b.base_size for b in self.buckets]
@property
def resolutions(self):
return [b.resolutions for b in self.buckets]
def load_buckets(self, index_files, **kwargs):
self.batch_size = kwargs.get("batch_size", None)
self.world_size = kwargs.get("world_size", None)
if self.batch_size is None or self.world_size is None:
raise ValueError(
"`batch_size` and `world_size` must be provided when using "
"`MultiMultiResolutionBucketIndexV2`."
)
buckets = [
MultiResolutionBucketIndexV2(
index_file,
self.batch_size,
self.world_size,
shadow_file_fn=kwargs.get("shadow_file_fn", None),
)
for index_file in index_files
]
return buckets
def sample_indices_with_probability(self, return_batch_indices=False):
bs = self.batch_size
ind_mapper_list = []
accu = 0
for bucket, p in zip(self.buckets, self.probability):
if p == 1:
# Just use all indices
batch_indices = np.arange(len(bucket) // bs) * bs + accu
else:
# Use all indices multiple times, and then sample some indices without replacement
repeat_times = int(p)
indices_part1 = np.arange(len(bucket) // bs).repeat(repeat_times) * bs
indices_part2 = (
self.sampler.choice(
len(bucket) // bs,
int(len(bucket) * (p / bs - repeat_times)),
replace=False,
)
* bs
)
batch_indices = (
np.sort(np.concatenate([indices_part1, indices_part2])) + accu
)
if return_batch_indices:
indices = batch_indices
else:
indices = np.stack(
[batch_indices + i for i in range(bs)], axis=1
).reshape(-1)
ind_mapper_list.append(indices)
accu += len(bucket)
ind_mapper = np.concatenate(ind_mapper_list)
return ind_mapper
def shuffle(self, seed=None, fast=False):
if self.sample_strategy == "probability":
# Notice: In order to resample indices when shuffling, shuffle will not preserve the
# initial sampled indices when loading the index.
pass
# Shuffle indexes
if seed is not None:
state = random.getstate()
random.seed(seed)
random.shuffle(self.buckets)
random.setstate(state)
else:
random.shuffle(self.buckets)
self.cum_length = self.calc_cum_length()
# Shuffle indices in each index
for i, bucket in enumerate(self.buckets):
bucket.shuffle(seed + i, fast=fast)
# Shuffle ind_mapper in batch level
if self.sample_strategy == "uniform":
batch_ind_mapper = (
np.arange(self.total_length // self.batch_size) * self.batch_size
)
elif self.sample_strategy == "probability":
batch_ind_mapper = self.sample_indices_with_probability(
return_batch_indices=True
)
else:
raise ValueError(f"Not supported sample_strategy {self.sample_strategy}.")
if seed is not None:
sampler = np.random.RandomState(seed)
sampler.shuffle(batch_ind_mapper)
else:
np.random.shuffle(batch_ind_mapper)
self.ind_mapper = np.stack(
[batch_ind_mapper + i for i in range(self.batch_size)], axis=1
).reshape(-1)
def get_ratio(self, ind, **kwargs):
"""
Get the ratio of the image by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
Returns
-------
width, height, ratio
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_ratio(ind - bias, **kwargs)
def get_target_size(self, ind):
"""
Get the target size of the image by in-dataset index.
Parameters
----------
ind: int
The in-dataset index.
Returns
-------
target_width, target_height
"""
ind = self.ind_mapper[ind]
i = bisect.bisect_right(self.cum_length, ind)
bias = self.cum_length[i - 1] if i > 0 else 0
return self.buckets[i].get_target_size(ind - bias)
def build_multi_resolution_bucket(
config_file,
base_size,
src_index_files,
save_file,
reso_step=64,
target_ratios=None,
align=1,
min_size=0,
md5_hw=None,
):
# Compute base size
resolutions = ResolutionGroup(
base_size, step=reso_step, target_ratios=target_ratios, align=align
)
print(resolutions)
save_file = Path(save_file)
save_file.parent.mkdir(exist_ok=True, parents=True)
if isinstance(src_index_files, str):
src_index_files = [src_index_files]
src_indexes = []
print(f"Loading indexes:")
for src_index_file in src_index_files:
src_indexes.append(ArrowIndexV2(src_index_file))
print(
f" {src_index_file} | cum_length: {src_indexes[-1].cum_length[-1]} | indices: {len(src_indexes[-1])}"
)
if md5_hw is None:
md5_hw = {}
arrow_files = src_indexes[0].arrow_files[:] # !!!important!!!, copy the list
for src_index in src_indexes[1:]:
arrow_files.extend(src_index.arrow_files[:])
cum_length = src_indexes[0].cum_length[:]
for src_index in src_indexes[1:]:
cum_length.extend([x + cum_length[-1] for x in src_index.cum_length])
print(f"cum_length: {cum_length[-1]}")
group_length_list = src_indexes[0].group_length[:]
for src_index in src_indexes[1:]:
group_length_list.extend(src_index.group_length[:])
total_indices = sum([len(src_index) for src_index in src_indexes])
total_group_length = sum(group_length_list)
assert (
total_indices == total_group_length
), f"Total indices {total_indices} != Total group length {total_group_length}"
buckets = [[] for _ in range(len(resolutions))]
cum_length_tmp = 0
total_index_count = 0
for src_index, src_index_file in zip(src_indexes, src_index_files):
index_count = 0
pbar = tqdm(src_index.indices.tolist())
for i in pbar:
try:
height = int(src_index.get_attribute_by_index(i, "height"))
width = int(src_index.get_attribute_by_index(i, "width"))
except Exception as e1:
try:
md5 = src_index.get_attribute_by_index(i, "md5")
height, width = md5_hw[md5]
except Exception as e2:
try:
width, height = src_index.get_image_by_index(i).size
except Exception as e3:
print(
f"Error: {e1} --> {e2} --> {e3}. We will skip this image."
)
continue
if height < min_size or width < min_size:
continue
ratio = height / width
idx = np.argmin(np.abs(resolutions.ratio - ratio))
buckets[idx].append(i + cum_length_tmp)
index_count += 1
print(f"Valid indices {index_count} in {src_index_file}.")
cum_length_tmp += src_index.cum_length[-1]
total_index_count += index_count
print(f"Total indices: {total_index_count}")
print(f"Making bucket index.")
indices = {}
for i, bucket in tqdm(enumerate(buckets)):
if len(bucket) == 0:
continue
reso = f"{resolutions[i]}"
resolutions.attr[i] = f"{len(bucket):>6d}"
indices[reso] = bucket
builder = IndexV2Builder(
data_type=[
"multi-resolution-bucket-v2",
f"base_size={base_size}",
f"reso_step={reso_step}",
f"target_ratios={target_ratios}",
f"align={align}",
f"min_size={min_size}",
f"src_files=",
]
+ [f"{src_index_file}" for src_index_file in src_index_files],
arrow_files=arrow_files,
cum_length=cum_length,
indices=indices,
config_file=config_file,
)
builder.build(save_file)
print(resolutions)
print(
f"Build index finished!\n\n"
f" Save path: {Path(save_file).absolute()}\n"
f" Number of indices: {sum([len(v) for k, v in indices.items()])}\n"
f"Number of arrow files: {len(arrow_files)}\n"
)
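# Usage sketch (illustrative only, not part of the original file): building a
# multi-resolution bucket index from base index V2 files. The paths below are hypothetical.
def _example_build_multi_resolution_bucket():
    build_multi_resolution_bucket(
        config_file=None,
        base_size=1024,
        src_index_files=["data/jsons/part1.json", "data/jsons/part2.json"],  # hypothetical
        save_file="data/jsons/bucket_1024.json",  # hypothetical
        reso_step=64,
        min_size=512,
    )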
import json
import numpy as np
from pathlib import Path
from functools import partial
from .indexer import ArrowIndexV2
from .bucket import (
ResolutionGroup,
MultiIndexV2,
MultiResolutionBucketIndexV2,
MultiMultiResolutionBucketIndexV2,
IndexV2Builder,
)
def load_index(
src,
multireso=False,
batch_size=1,
world_size=1,
sample_strategy="uniform",
probability=None,
shadow_file_fn=None,
seed=None,
):
if isinstance(src, str):
src = [src]
if src[0].endswith(".arrow"):
if multireso:
raise ValueError(
"Arrow file does not support multiresolution. Please make base index V2 first and then"
"build multiresolution index."
)
idx = IndexV2Builder(src).to_index_v2()
elif src[0].endswith(".json"):
if multireso:
if len(src) == 1:
idx = MultiResolutionBucketIndexV2(
src[0],
batch_size=batch_size,
world_size=world_size,
shadow_file_fn=shadow_file_fn,
)
else:
idx = MultiMultiResolutionBucketIndexV2(
src,
batch_size=batch_size,
world_size=world_size,
sample_strategy=sample_strategy,
probability=probability,
shadow_file_fn=shadow_file_fn,
seed=seed,
)
else:
if len(src) == 1:
idx = ArrowIndexV2(
src[0],
shadow_file_fn=shadow_file_fn,
)
else:
idx = MultiIndexV2(
src,
sample_strategy=sample_strategy,
probability=probability,
shadow_file_fn=shadow_file_fn,
seed=seed,
)
else:
raise ValueError(f"Unknown file type: {src[0]}")
return idx
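# Usage sketch (illustrative only, not part of the original file): loading a
# multi-resolution bucket index for distributed training. Path and sizes are hypothetical.
def _example_load_index():
    idx = load_index(
        "data/jsons/bucket_1024.json",  # hypothetical path
        multireso=True,
        batch_size=8,
        world_size=8,
    )
    return idx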
def get_attribute(data, attr_list):
ret_data = {}
for attr in attr_list:
ret_data[attr] = data.get(attr, None)
if ret_data[attr] is None:
raise ValueError(f"Missing {attr} in {data}")
return ret_data
def get_optional_attribute(data, attr_list):
ret_data = {}
for attr in attr_list:
ret_data[attr] = data.get(attr, None)
return ret_data
def detect_index_type(data):
if isinstance(data["group_length"], dict):
return "multireso"
else:
return "base"
def show_index_info(src, only_arrow_files=False, depth=1):
"""
Show base/multireso index information.
"""
if not Path(src).exists():
raise ValueError(f"{src} does not exist.")
print(f"Loading index file {src} ...")
with open(src, "r") as f:
src_data = json.load(f)
print(f"Loaded.")
data = get_attribute(
src_data,
[
"data_type",
"indices_file",
"arrow_files",
"cum_length",
"group_length",
"indices",
"example_indices",
],
)
opt_data = get_optional_attribute(src_data, ["config_file"])
# Format arrow_files examples
arrow_files = data["arrow_files"]
if only_arrow_files:
existed = set()
arrow_files_output_list = []
for arrow_file in arrow_files:
if depth == 0:
if arrow_file not in existed:
arrow_files_output_list.append(arrow_file)
existed.add(arrow_file)
elif depth > 0:
parts = Path(arrow_file).parts
if depth >= len(parts):
continue
else:
arrow_file_part = "/".join(parts[:-depth])
if arrow_file_part not in existed:
arrow_files_output_list.append(arrow_file_part)
existed.add(arrow_file_part)
else:
                    raise ValueError(
                        f"Invalid depth {depth}. Depth must be a non-negative integer."
                    )
arrow_files_repr = "\n".join(arrow_files_output_list)
print(arrow_files_repr)
return None
return_space = "\n" + " " * 21
if len(arrow_files) <= 4:
arrow_files_repr = return_space.join([arrow_file for arrow_file in arrow_files])
else:
arrow_files_repr = return_space.join(
[_ for _ in arrow_files[:2]] + ["..."] + [_ for _ in arrow_files[-2:]]
)
arrow_files_length = len(arrow_files)
# Format data_type
data_type = data["data_type"]
if isinstance(data_type, str):
data_type = [data_type]
data_type_common = []
src_files = []
found_src_files = False
for data_type_item in data_type:
if not found_src_files and data_type_item.strip() != "src_files=":
data_type_common.append(data_type_item.strip())
continue
found_src_files = True
if data_type_item.endswith(".json"):
src_files.append(data_type_item.strip())
else:
data_type_common.append(data_type_item.strip())
data_type_part2_with_ids = []
max_id_len = len(str(len(src_files)))
for sid, data_type_item in enumerate(src_files, start=1):
data_type_part2_with_ids.append(
f"{str(sid).rjust(max_id_len)}. {data_type_item}"
)
data_type = data_type_common + data_type_part2_with_ids
data_repr = return_space.join(data_type)
# Format cum_length examples
cum_length = data["cum_length"]
if len(cum_length) <= 8:
cum_length_repr = ", ".join([str(i) for i in cum_length])
else:
cum_length_repr = ", ".join(
[str(i) for i in cum_length[:4]]
+ ["..."]
+ [str(i) for i in cum_length[-4:]]
)
cum_length_length = len(cum_length)
if detect_index_type(data) == "base":
# Format group_length examples
group_length = data["group_length"]
if len(group_length) <= 8:
group_length_repr = ", ".join([str(i) for i in group_length])
else:
group_length_repr = ", ".join(
[str(i) for i in group_length[:4]]
+ ["..."]
+ [str(i) for i in group_length[-4:]]
)
group_length_length = len(group_length)
# Format indices examples
indices = data["indices"]
if len(indices) == 0 and data["indices_file"] != "":
indices_file = Path(src).parent / data["indices_file"]
if Path(indices_file).exists():
print(f"Loading indices from {indices_file} ...")
indices = np.load(indices_file)["x"]
print(f"Loaded.")
else:
                raise ValueError(
                    f"This index file references an external indices file {indices_file}, which is missing."
                )
if len(indices) <= 8:
indices_repr = ", ".join([str(i) for i in indices])
else:
indices_repr = ", ".join(
[str(i) for i in indices[:4]] + ["..."] + [str(i) for i in indices[-4:]]
)
# Calculate indices total length
indices_length = len(indices)
print_str = f"""File: {Path(src).absolute()}
ArrowIndexV2(
\033[4mdata_type:\033[0m {data_repr}"""
# Process optional data
if opt_data["config_file"] is not None:
print_str += f"""
\033[4mconfig_file:\033[0m {opt_data['config_file']}"""
# Add common data
print_str += f"""
\033[4mindices_file:\033[0m {data['indices_file']}
\033[4marrow_files: Count = {arrow_files_length:,}\033[0m
Examples: {arrow_files_repr}
\033[4mcum_length: Count = {cum_length_length:,}\033[0m
Examples: {cum_length_repr}
\033[4mgroup_length: Count = {group_length_length:,}\033[0m
Examples: {group_length_repr}
\033[4mindices: Count = {indices_length:,}\033[0m
Examples: {indices_repr}"""
else:
group_length = data["group_length"]
indices_file = Path(src).parent / data["indices_file"]
assert Path(indices_file).exists(), f"indices_file {indices_file} not found"
print(f"Loading indices from {indices_file} ...")
indices_data = np.load(indices_file)
print(f"Loaded.")
indices_length = sum([len(indices) for key, indices in indices_data.items()])
keys = [k for k in group_length.keys() if len(indices_data[k]) > 0]
resolutions = ResolutionGroup.from_list_of_hxw(keys)
        resolutions.attr = [f"{len(indices_data[k]):>,d}" for k in keys]
resolutions.prefix_space = 25
print_str = f"""File: {Path(src).absolute()}
MultiResolutionBucketIndexV2(
\033[4mdata_type:\033[0m {data_repr}"""
# Process optional data
if opt_data["config_file"] is not None:
print_str += f"""
\033[4mconfig_file:\033[0m {opt_data['config_file']}"""
# Process config files of base index files
config_files = []
for src_file in src_files:
src_file = Path(src_file)
if src_file.exists():
with src_file.open() as f:
base_data = json.load(f)
if "config_file" in base_data:
config_files.append(base_data["config_file"])
else:
config_files.append("Unknown")
else:
config_files.append("Missing the src file")
if config_files:
config_file_str = return_space.join(
[
f"{str(sid).rjust(max_id_len)}. {config_file}"
for sid, config_file in enumerate(config_files, start=1)
]
)
print_str += f"""
\033[4mbase config files:\033[0m {config_file_str}"""
# Add common data
print_str += f"""
\033[4mindices_file:\033[0m {data['indices_file']}
\033[4marrow_files: Count = {arrow_files_length:,}\033[0m
Examples: {arrow_files_repr}
\033[4mcum_length: Count = {cum_length_length:,}\033[0m
Examples: {cum_length_repr}
\033[4mindices: Count = {indices_length:,}\033[0m
\033[4mbuckets: Count = {len(keys)}\033[0m
{resolutions}"""
print(print_str + "\n)\n")
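# Usage sketch (illustrative only, not part of the original file): printing a summary of
# an index file, or only its deduplicated arrow file path prefixes. Path is hypothetical.
def _example_show_index_info():
    show_index_info("data/jsons/bucket_1024.json")
    show_index_info("data/jsons/bucket_1024.json", only_arrow_files=True, depth=1)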
import yaml
from glob import glob
import pathlib
from pathlib import Path
import pickle
import json
from typing import List
from collections import defaultdict
def source_names_to_names(source_names_):
"""
Get all arrow file names from a list of source names.
Examples:
`test.arrow@@x2` means repeat the arrow file `test.arrow` 2 times.
"""
names = []
exclude_count = 0
for data_name in source_names_:
if isinstance(data_name, dict):
data_name, value = list(data_name.items())[0]
exclude = value.get("exclude", [])
repeat_times = value.get("repeat", 1)
else:
exclude = []
repeat_times = 1
cur_count = 0
for name in glob(data_name):
for e in exclude:
if e in name:
print(f"Excluding {name}")
exclude_count += 1
break
else:
names.append((name, repeat_times))
cur_count += 1
print(f"Found {cur_count} arrow files in {data_name}.")
names = list(sorted(names, key=lambda x: x[0]))
return names, exclude_count
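# Usage sketch (illustrative only, not part of the original file): `source_names_` mixes
# plain glob strings with dict entries carrying `exclude`/`repeat` options, as parsed above.
# The globs below are hypothetical.
def _example_source_names():
    source_names_ = [
        "data/arrows/part1/*.arrow",
        {"data/arrows/part2/*.arrow": {"exclude": ["broken"], "repeat": 2}},
    ]
    names, exclude_count = source_names_to_names(source_names_)
    return names, exclude_count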
def get_or_create_arrow_name_list(source_names=None):
if not isinstance(source_names, list):
raise ValueError(
f"`source_names` should be of type list, got {type(source_names)}"
)
if len(source_names) == 0:
raise ValueError(f"`source_names` is an empty list.")
names, exclude_count = source_names_to_names(source_names)
print(f"Found {len(names):,} arrow files.")
print(f"Excluded {exclude_count} arrow files.")
return names
def to_numeric(x, default_val=0.0):
if isinstance(x, (float, int)):
pass
elif isinstance(x, str):
try:
x = float(x)
except ValueError:
x = default_val
else:
x = default_val
return x
TYPE_MAPPER = {
"int": int,
"float": float,
"str": str,
"dict": dict,
"list": list,
}
class Operator(object):
def __init__(self, config):
self.config = config
self.arrow_file_keyword = None
def applicable(self, arrow_file):
# If no keyword is provided, then it is applicable to all arrow files
if self.arrow_file_keyword is None:
return True
# If keyword is provided, then it is applicable to the arrow file if the keyword is in the arrow file name
for keyword in self.arrow_file_keyword:
if keyword in arrow_file:
return True
return False
@staticmethod
def parse_path(paths):
if isinstance(paths, str):
paths = [paths]
if not isinstance(paths, list):
raise ValueError(f"Invalid type: {type(paths)}")
paths = [Path(p) for p in paths]
return paths
@staticmethod
def load_md5s(paths: List[pathlib.Path], merge=True):
md5s_list = []
for path in paths:
if not path.exists():
raise ValueError(f"Path not found: {path}")
if path.suffix == ".pkl":
with path.open("rb") as f:
md5s = pickle.load(f)
assert isinstance(md5s, (set, dict)), f"Invalid type: {type(md5s)}"
elif path.suffix == ".json":
with path.open() as f:
md5s = json.load(f)
assert isinstance(md5s, dict), f"Invalid type: {type(md5s)}"
elif path.suffix == ".txt":
md5s = set(path.read_text().splitlines())
else:
raise ValueError(f"Invalid file type: {path}")
md5s_list.append(md5s)
print(f" {path}: {len(md5s):,} md5s")
if merge:
md5s = md5s_list[0]
for m in md5s_list[1:]:
md5s.update(m)
else:
md5s = md5s_list
return md5s
class ColumnFilter(Operator):
numeric_actions = {"eq", "ne", "gt", "lt", "ge", "le"}
str_actions = {
"eq",
"ne",
"len_eq",
"len_ne",
"len_gt",
"len_lt",
"len_ge",
"len_le",
"contains",
"not_contains",
"in",
"not_in",
"lower_last_in",
}
    int_target_str_actions = {"len_eq", "len_ne", "len_gt", "len_lt", "len_ge", "len_le"}
action_mapper_left = {
"eq": "==",
"ne": "!=",
"len_eq": ".len()==",
"gt": ">",
"len_gt": ".len()>",
"lt": "<",
"len_lt": ".len()<",
"ge": ">=",
"len_ge": ".len()>=",
"le": "<=",
"len_le": ".len()<=",
"contains": ".contains(",
"not_contains": ".not_contains(",
"in": ".isin(",
"not_in": ".notin(",
"lower_last_in": "[-1].lower().isin(",
}
action_mapper_right = {
"eq": "",
"ne": "",
"len_eq": "",
"gt": "",
"len_gt": "",
"lt": "",
"len_lt": "",
"ge": "",
"len_ge": "",
"le": "",
"len_le": "",
"contains": ")",
"not_contains": ")",
"in": ")",
"not_in": ")",
"lower_last_in": ")",
}
def __init__(self, config):
super().__init__(config)
self.name = self.column_name = config["name"]
self.dtype = config["type"]
self.action = config["action"]
self.target = config["target"]
self.default = config["default"]
self.arrow_file_cond = config.get("arrow_file", None)
self.arrow_file_keyword = config.get("arrow_file_keyword", None)
if self.arrow_file_keyword is not None:
if isinstance(self.arrow_file_keyword, str):
self.arrow_file_keyword = [self.arrow_file_keyword]
def check_exists(self, arrow_file, table):
status = True
if self.column_name not in table.column_names:
print(f"Warning: Column `{self.column_name}` not found in {arrow_file}.")
status = status and False
return status
@property
def data_type(self):
if self.arrow_file_keyword is None or len(self.arrow_file_keyword) == 0:
arrow_file_keyword_repr = ""
else:
arrow_file_keyword_repr = f", arrow_file_keys={self.arrow_file_keyword}"
fmt_str = (
f"{self.column_name}"
f"{self.action_mapper_left[self.action]}{self.target}{self.action_mapper_right[self.action]}"
f" (default={self.default}{arrow_file_keyword_repr})"
)
return fmt_str
class NumericFilter(ColumnFilter):
def run_on_column(self, column):
if self.action == "eq":
return column == self.target
elif self.action == "ne":
return column != self.target
elif self.action == "gt":
return column > self.target
elif self.action == "lt":
return column < self.target
elif self.action == "ge":
return column >= self.target
elif self.action == "le":
return column <= self.target
else:
raise ValueError(f"Invalid action: {self.action}")
class IntFilter(NumericFilter):
def __call__(self, arrow_file, table):
success = self.check_exists(arrow_file, table)
if not success:
return None
column = (
table[self.column_name]
.to_pandas()
.apply(lambda x: to_numeric(x, self.default))
)
return self.run_on_column(column)
class FloatFilter(NumericFilter):
def __call__(self, arrow_file, table):
success = self.check_exists(arrow_file, table)
if not success:
return None
column = (
table[self.column_name]
.to_pandas()
.apply(lambda x: to_numeric(x, self.default))
)
return self.run_on_column(column)
class StrFilter(ColumnFilter):
def __call__(self, arrow_file, table):
success = self.check_exists(arrow_file, table)
if not success:
return None
column = table[self.column_name].to_pandas()
if self.action == "eq":
return column.apply(lambda x: x == self.target)
elif self.action == "ne":
return column.apply(lambda x: x != self.target)
elif self.action == "len_eq":
return column.str.len() == self.target
elif self.action == "len_ne":
return column.str.len() != self.target
elif self.action == "len_gt":
return column.str.len() > self.target
elif self.action == "len_lt":
return column.str.len() < self.target
elif self.action == "len_ge":
return column.str.len() >= self.target
elif self.action == "len_le":
return column.str.len() <= self.target
elif self.action == "contains":
return column.str.contains(self.target)
elif self.action == "not_contains":
return ~column.str.contains(self.target)
elif self.action == "in":
return column.apply(lambda x: x in self.target)
elif self.action == "not_in":
return column.apply(lambda x: x not in self.target)
elif self.action == "lower_last_in":
return column.apply(lambda x: x[-1].lower() in self.target)
else:
raise ValueError(f"Invalid action: {self.action}")
def check_type(parent, config):
if "type" not in config:
return False, f"Missing required argument: {parent}.type"
if not isinstance(config["type"], str):
return False, f"Argument {parent}.type should be of type str"
if config["type"] not in TYPE_MAPPER:
return False, f"Invalid type: {parent}.type: {config['type']}"
return True, ""
def check_keys(parent, config, required_args, types):
for arg, ts in zip(required_args, types):
# Check key exists
if arg not in config:
return False, f"Missing required argument: {parent}.{arg}"
# Check values' type
if ts is not None:
if not isinstance(ts, list):
ts = [ts]
for t in ts:
if not isinstance(config[arg], t):
return (
False,
f"Argument {parent}.{arg} should be of type {t}, got {type(config[arg])}",
)
return True, ""
def get_column_filter(parent, config):
success, msg = check_type(parent, config)
if not success:
return False, msg, None
if (
config["type"] == "str"
and config["action"] in ColumnFilter.int_target_str_actions
):
dtype = TYPE_MAPPER["int"]
else:
dtype = TYPE_MAPPER[config["type"]]
# Check other keys
required_args = ["name", "action", "target", "default"]
types = [str, str, dtype, dtype]
success, msg = check_keys(parent, config, required_args, types)
if not success:
return False, msg, None
# Check action values
if config["type"] == "str":
valid_actions = ColumnFilter.str_actions
else:
valid_actions = ColumnFilter.numeric_actions
if config["action"] not in valid_actions:
return False, f"Invalid action: {parent}.action: {config['action']}", None
if config["type"] == "int":
return True, "", IntFilter(config)
elif config["type"] == "float":
return True, "", FloatFilter(config)
elif config["type"] == "str":
return True, "", StrFilter(config)
else:
raise ValueError(f"Invalid type: {config['type']}")
class MD5Filter(Operator):
valid_actions = {
"list": {"in", "not_in"},
"dict": {"eq", "ne", "gt", "lt", "ge", "le"},
}
action_mapper_left = {
"eq": "==",
"ne": "!=",
"gt": ">",
"lt": "<",
"ge": ">=",
"le": "<=",
"in": ".isin(",
"not_in": ".isin(",
}
action_mapper_right = {
"eq": "",
"ne": "",
"gt": "",
"lt": "",
"ge": "",
"le": "",
"in": ")",
"not_in": ")",
}
def __init__(self, config):
super().__init__(config)
self.name = config["name"]
self.paths = self.parse_path(config["path"])
self.dtype = config["type"]
self.action = config["action"]
self.target = config.get("target", None) # only for type=='dict'
self.is_valid = config.get("is_valid", None) # only for type=='dict'
self.arrow_file_keyword = config.get("arrow_file_keyword", None)
if self.arrow_file_keyword is not None:
if isinstance(self.arrow_file_keyword, str):
self.arrow_file_keyword = [self.arrow_file_keyword]
self.md5_data = self.load_md5s(self.paths)
@property
def data_type(self):
if self.arrow_file_keyword is None or len(self.arrow_file_keyword) == 0:
arrow_file_keyword_repr = ""
else:
arrow_file_keyword_repr = f"(arrow_file_keys={self.arrow_file_keyword})"
return_space = "\n" + " " * 22
if self.dtype == "list":
fmt_str = (
("Good Cases" if self.is_valid else " Bad Cases")
+ f" (md5): {return_space.join([str(p) for p in self.paths])} "
f"{arrow_file_keyword_repr} "
f"| {self.name}"
)
elif self.dtype == "dict":
fmt_str = (
("Good Cases" if self.is_valid else " Bad Cases")
+ f" (md5): {return_space.join([str(p) for p in self.paths])}\n"
f" --> value"
f"{self.action_mapper_left[self.action]}{self.target}{self.action_mapper_right[self.action]} "
f"{arrow_file_keyword_repr} "
f"| {self.name}"
)
else:
raise ValueError(f"Invalid type: {self.dtype}")
return fmt_str
class ListFilter(MD5Filter):
def __call__(self, md5s):
if self.action == "in":
find_valid = self.is_valid
elif self.action == "not_in":
find_valid = not self.is_valid
else:
raise ValueError(f"Invalid action: {self.action}")
if find_valid:
return md5s.apply(lambda x: x in self.md5_data)
else:
return md5s.apply(lambda x: x not in self.md5_data)
class DictFilter(MD5Filter):
def __call__(self, md5s):
if self.is_valid:
return md5s.apply(
lambda x: x in self.md5_data and self.cond(self.md5_data[x])
)
else:
return md5s.apply(
lambda x: (x not in self.md5_data) or not self.cond(self.md5_data[x])
)
def cond(self, value):
if self.action == "eq":
return value == self.target
elif self.action == "ne":
return value != self.target
elif self.action == "gt":
return value > self.target
elif self.action == "lt":
return value < self.target
elif self.action == "ge":
return value >= self.target
elif self.action == "le":
return value <= self.target
else:
raise ValueError(f"Invalid action: {self.action}")
def get_md5_filter(parent, config):
success, msg = check_type(parent, config)
if not success:
return False, msg, None
required_args = ["name", "path", "action"]
types = [str, (str, list), str]
if config["type"] == "dict":
required_args.extend(["target", "is_valid"])
types.extend([None, bool])
success, msg = check_keys(parent, config, required_args, types)
if not success:
return False, msg, None
if config["action"] not in MD5Filter.valid_actions[config["type"]]:
return False, f"Invalid action: {parent}.action: {config['action']}", None
if config["type"] == "list":
return True, "", ListFilter(config)
elif config["type"] == "dict":
return True, "", DictFilter(config)
else:
raise ValueError(f"Invalid type: {config['type']}")
class FilterCompose(object):
def __init__(self, column_filter_list, md5_filter_list):
self.column_filter_list = column_filter_list
self.md5_filter_list = md5_filter_list
def __call__(self, arrow_file, table):
stats = {}
length = len(table)
assert length > 0, "Empty table"
mask = None
for filter_ in self.column_filter_list:
if isinstance(filter_, tuple):
op, filter_list = filter_
if op == "logical_or":
sub_mask = None
for sub_filter in filter_list:
if not sub_filter.applicable(arrow_file):
continue
sub_current_mask = sub_filter(arrow_file, table)
if sub_current_mask is not None:
if sub_mask is None:
sub_mask = sub_current_mask
else:
sub_mask = sub_mask | sub_current_mask
if sub_mask is not None:
name = "|".join([f.name for f in filter_list])
stats.update({name: length - sum(sub_mask)})
if mask is None:
mask = sub_mask
else:
mask = mask & sub_mask
else:
raise ValueError(f"Invalid operation: {op}")
else:
if not filter_.applicable(arrow_file):
continue
current_mask = filter_(arrow_file, table)
if current_mask is not None:
stats.update({filter_.name: length - sum(current_mask)})
if mask is None:
mask = current_mask
else:
mask = mask & current_mask
md5s = None
for filter_ in self.md5_filter_list:
if not filter_.applicable(arrow_file):
continue
if "md5" not in table.column_names:
print(f"Warning: Column 'md5' not found in {arrow_file}.")
md5s = table["md5"].to_pandas()
current_mask = filter_(md5s)
if current_mask is not None:
stats.update({filter_.name: length - sum(current_mask)})
if mask is None:
mask = current_mask
else:
mask = mask & current_mask
return mask, stats, md5s
@property
def data_type(self):
data_type_list = []
for filter_ in self.column_filter_list + self.md5_filter_list:
if isinstance(filter_, tuple):
data_type_list.append(" || ".join([f.data_type for f in filter_[1]]))
else:
data_type_list.append(filter_.data_type)
return data_type_list
class MD5Repeater(Operator):
def __init__(self, config):
super().__init__(config)
self.name = config["name"]
self.paths = self.parse_path(config["path"])
self.dtype = config["type"]
self.plus = config.get("plus", 0)
self.repeat = config.get("repeat", None)
self.arrow_file_keyword = config.get("arrow_file_keyword", None)
if self.arrow_file_keyword is not None:
if isinstance(self.arrow_file_keyword, str):
self.arrow_file_keyword = [self.arrow_file_keyword]
self.md5_data = self.load_md5s(self.paths)
if self.repeat is None:
# Check if md5_data.values() are integers
for v in self.md5_data.values():
if not isinstance(v, int):
raise ValueError(
f"Values from {self.paths} are not integers. For example: {v}"
)
# We only check the first value for performance
# We assume all values are the same type
break
def __call__(self, arrow_file, index, md5):
if md5 in self.md5_data:
if self.repeat is None:
return self.md5_data[md5] + self.plus
else:
return self.repeat
else:
return 1
@property
def data_type(self):
path_repr = ", ".join([str(p) for p in self.paths])
if self.dtype == "list":
fmt_str = f"[MD5] Repeat {self.repeat} times: {path_repr}"
elif self.dtype == "dict":
fmt_str = f"[MD5] Repeat multiple times: {path_repr}"
else:
raise ValueError(f"Invalid type: {self.dtype}")
return fmt_str
def get_md5_repeater(parent, config):
success, msg = check_type(parent, config)
if not success:
return False, msg, None
required_args = ["name", "path"]
    types = [str, (str, list)]
if config["type"] == "list":
required_args.extend(["repeat"])
types.extend([int])
success, msg = check_keys(parent, config, required_args, types)
if not success:
return False, msg, None
return True, "", MD5Repeater(config)
class KeywordRepeater(Operator):
def __init__(self, config):
super().__init__(config)
self.keywords = config["keyword"]
self.repeat = config["repeat"]
self.name = f"Repeat {self.repeat} times"
def __call__(self, arrow_file, idx, md5):
for key in self.keywords:
if key in arrow_file:
return self.repeat
return 1
@property
def data_type(self):
fmt_str = f"[Keyword] Repeat {self.repeat} times: {self.keywords}"
return fmt_str
def get_keyword_repeater(parent, config):
required_args = ["repeat", "keyword"]
types = [int, list]
success, msg = check_keys(parent, config, required_args, types)
if not success:
return False, msg, None
return True, "", KeywordRepeater(config)
class RepeaterCompose(object):
def __init__(self, repeater_list):
self.repeater_list = repeater_list
def __call__(self, arrow_file, table, indices, repeat_times=1, md5s=None):
stats = defaultdict(int)
length = len(table)
assert length > 0, "Empty table"
if md5s is None:
if "md5" not in table.column_names:
print(f"Warning: Column 'md5' not found in {arrow_file}.")
md5s = None
else:
md5s = table["md5"].to_pandas()
repeated_indices = []
for idx in indices:
md5 = md5s[idx] if md5s is not None else None
max_repeat = repeat_times
max_i = -1
for i, repeater in enumerate(self.repeater_list):
if not repeater.applicable(arrow_file):
continue
repeat = repeater(arrow_file, idx, md5)
if repeat > max_repeat:
max_repeat = repeat
max_i = i
if max_i >= 0:
stats[self.repeater_list[max_i].name] += max_repeat - 1
repeated_indices.extend([idx] * max_repeat)
return repeated_indices, stats
@property
def data_type(self):
data_type_list = []
for repeater in self.repeater_list:
data_type_list.append(repeater.data_type)
return data_type_list
class DatasetConfig(object):
def __init__(self, work_dir, config_file):
self.work_dir = work_dir
self.config_file = str(Path(config_file).absolute())
with open(self.config_file, "r") as f:
self.data = yaml.safe_load(f)
self.names = self.parse_names() # arrow names
arrow_max_repeat = max([x[1] for x in self.names])
self.filter = self.parse_filter()
self.repeater = self.parse_repeater(enable_arrow_repeat=arrow_max_repeat > 1)
# Extra arguments
self.remove_md5_dup = self.data.get("remove_md5_dup", False)
def assert_unknown_args(self):
unknown_args = set(self.data.keys()) - {
"source",
"filter",
"repeater",
"remove_md5_dup",
}
if len(unknown_args) > 0:
raise ValueError(
f"Unknown arguments in config file ({self.config_file}): {unknown_args}"
)
def parse_filter(self):
column_filter_list = []
md5_filter_list = []
if "filter" in self.data:
filters = self.data["filter"]
if "column" in filters:
column_filter_configs = filters["column"]
assert isinstance(
column_filter_configs, list
), "filter.column should be a list."
for i, config in enumerate(column_filter_configs):
if config.get("logical_or", None) is not None:
assert isinstance(
config["logical_or"], list
), f"filter.column[{i}].logical_or should be a list, got {type(config['logical_or'])}"
sub_column_filter_list = []
for j, sub_config in enumerate(config["logical_or"]):
sub_success, sub_msg, sub_filter_ = get_column_filter(
f"filter.column[{i}-logical_or].[{j}]", sub_config
)
if not sub_success:
raise ValueError(sub_msg)
sub_column_filter_list.append(sub_filter_)
success = True
msg = ""
filter_ = ("logical_or", sub_column_filter_list)
else:
success, msg, filter_ = get_column_filter(
f"filter.column[{i}]", config
)
if not success:
raise ValueError(msg)
column_filter_list.append(filter_)
if "md5" in filters:
md5_filter_configs = filters["md5"]
assert isinstance(
md5_filter_configs, list
), "filter.md5 should be a list."
for i, config in enumerate(md5_filter_configs):
if config.get("logical_or", None) is not None:
assert isinstance(
config["logical_or"], list
), f"filter.md5[{i}].logical_or should be a list, got {type(config['logical_or'])}"
sub_md5_filter_list = []
for j, sub_config in enumerate(config["logical_or"]):
sub_success, sub_msg, sub_filter_ = get_md5_filter(
f"filter.md5[{i}-logical_or].[{j}]", sub_config
)
if not sub_success:
raise ValueError(sub_msg)
sub_md5_filter_list.append(sub_filter_)
success = True
msg = ""
filter_ = ("logical_or", sub_md5_filter_list)
else:
success, msg, filter_ = get_md5_filter(
f"filter.md5[{i}]", config
)
if not success:
raise ValueError(msg)
md5_filter_list.append(filter_)
if column_filter_list or md5_filter_list:
composed_filter = FilterCompose(column_filter_list, md5_filter_list)
else:
composed_filter = None
return composed_filter
def parse_repeater(self, enable_arrow_repeat=False):
repeater_list = []
if "repeater" in self.data:
repeaters = self.data["repeater"]
if "md5" in repeaters:
md5_repeater_configs = repeaters["md5"]
assert isinstance(
md5_repeater_configs, list
), "repeater.md5 should be a list."
for i, config in enumerate(md5_repeater_configs):
success, msg, repeater = get_md5_repeater(
f"repeater.md5[{i}]", config
)
if not success:
raise ValueError(msg)
repeater_list.append(repeater)
if "arrow_file_keyword" in repeaters:
keyword_repeater_configs = repeaters["arrow_file_keyword"]
assert isinstance(
keyword_repeater_configs, list
), "repeater.arrow_file_keyword should be a list."
for i, config in enumerate(keyword_repeater_configs):
success, msg, repeater = get_keyword_repeater(
f"repeater.arrow_file_keyword[{i}]", config
)
if not success:
raise ValueError(msg)
repeater_list.append(repeater)
if repeater_list or enable_arrow_repeat:
composed_repeater = RepeaterCompose(repeater_list)
else:
composed_repeater = None
return composed_repeater
def parse_names(self):
if "source" in self.data:
source = self.data["source"]
assert isinstance(source, list), "source should be a list."
else:
            raise ValueError(
                "The `source` field is missing from the YAML config file. Please check the config."
            )
names = get_or_create_arrow_name_list(source)
return names
@property
def data_type(self):
data_types = {
"Filters": [] if self.filter is None else self.filter.data_type,
"Repeaters": [] if self.repeater is None else self.repeater.data_type,
}
return data_types
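# Config sketch (illustrative only, not part of the original file): a minimal YAML dataset
# config of the shape parsed by DatasetConfig above. All paths and values are hypothetical.
_EXAMPLE_DATASET_CONFIG_YAML = """
source:
  - data/arrows/part1/*.arrow
  - data/arrows/part2/*.arrow:
      exclude: [broken]
      repeat: 2
filter:
  column:
    - {name: width, type: int, action: ge, target: 512, default: 0}
    - {name: height, type: int, action: ge, target: 512, default: 0}
  md5:
    - {name: bad_cases, type: list, path: data/md5s/bad_cases.txt, action: in, is_valid: false}
repeater:
  arrow_file_keyword:
    - {repeat: 3, keyword: [high_quality]}
remove_md5_dup: true
"""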
import json
import pickle
from collections import defaultdict
from glob import glob
from multiprocessing import Pool
from pathlib import Path
import pandas as pd
import yaml
import numpy as np
import pyarrow as pa
from tqdm import tqdm
from index_kits.indexer import IndexV2Builder
from index_kits.bucket import build_multi_resolution_bucket
from index_kits.dataset.config_parse import DatasetConfig
def get_table(arrow_file):
return pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
def get_indices(arrow_file, repeat_times, filter_fn, repeat_fn, callback=None):
"""
Get valid indices from a single arrow_file.
Parameters
----------
arrow_file: str
repeat_times: int
Repeat remain indices multiple times.
filter_fn
callback
Returns
-------
"""
try:
table = pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
except Exception as e:
print(arrow_file, e)
raise e
length = len(table)
if len(table) == 0:
print(f"Warning: Empty table: {arrow_file}")
indices = []
stats = {}
else:
# Apply filter_fn if available
if filter_fn is not None:
mask, stats, md5s = filter_fn(arrow_file, table)
else:
mask = pd.Series([True] * length)
stats = {}
md5s = None
# Apply callback function if available
if callback is not None:
mask, stats = callback(arrow_file, table, mask, stats, md5s)
# Get indices
if mask is not None:
indices = np.where(mask)[0].tolist()
else:
indices = list(range(length))
# Apply indices repeat
if repeat_fn is not None:
indices, repeat_stats = repeat_fn(
arrow_file, table, indices, repeat_times, md5s
)
stats.update(repeat_stats)
return arrow_file, length, indices, stats
def load_md5_files(files, name=None):
if isinstance(files, str):
files = [files]
md5s = set()
for file in files:
md5s.update(Path(file).read_text().splitlines())
print(f" {name} md5s: {len(md5s):,}")
return md5s
def load_md52cls_files(files, name=None):
if isinstance(files, str):
files = [files]
md52cls = {}
for file in files:
with Path(file).open() as f:
md52cls.update(json.load(f))
print(f" {name} md52cls: {len(md52cls):,}")
return md52cls
def merge_and_build_index(data_type, src, dconfig, save_path):
if isinstance(src, str):
files = list(sorted(glob(src)))
else:
files = list(sorted(src))
print(f"Found {len(files):,} temp pickle files.")
for fname in files:
print(f" {fname}")
arrow_files = []
table_lengths = []
indices_list = []
bad_stats_total = defaultdict(int)
total_indices = 0
total_processed_length = 0
for file_name in tqdm(files):
with Path(file_name).open("rb") as f:
data = pickle.load(f)
for arrow_file, table_length, indices, *args in tqdm(data, leave=False):
arrow_files.append(arrow_file)
table_lengths.append(table_length)
total_processed_length += table_length
indices_list.append(indices)
total_indices += len(indices)
if len(args) > 0 and args[0]:
bad_stats = args[0]
for k, v in bad_stats.items():
bad_stats_total[k] += v
if len(bad_stats_total):
stats_save_dir = Path(save_path).parent
stats_save_dir.mkdir(parents=True, exist_ok=True)
stats_save_path = stats_save_dir / (Path(save_path).stem + "_stats.txt")
stats_save_path.write_text(
"\n".join([f"{k:>50s} {v}" for k, v in bad_stats_total.items()]) + "\n"
)
print(f"Save stats to {stats_save_path}")
print(f"Arrow files: {len(arrow_files):,}")
print(f"Processed indices: {total_processed_length:,}")
print(f"Valid indices: {total_indices:,}")
cum_length = 0
total_indices = []
cum_lengths = []
group_lengths = []
existed = set()
print(f"Accumulating indices...")
pbar = tqdm(
zip(arrow_files, table_lengths, indices_list),
total=len(arrow_files),
mininterval=1,
)
_count = 0
for arrow_file, table_length, indices in pbar:
if len(indices) > 0 and dconfig.remove_md5_dup:
new_indices = []
table = get_table(arrow_file)
if "md5" not in table.column_names:
raise ValueError(
f"Column 'md5' not found in {arrow_file}. "
f"When `remove_md5_dup: true` is set, md5 column is required."
)
md5s = table["md5"].to_pandas()
for i in indices:
md5 = md5s[i]
if md5 in existed:
continue
existed.add(md5)
new_indices.append(i)
indices = new_indices
total_indices.extend([int(i + cum_length) for i in indices])
cum_length += table_length
cum_lengths.append(cum_length)
group_lengths.append(len(indices))
_count += 1
if _count % 100 == 0:
pbar.set_description(f"Indices: {len(total_indices):,}")
builder = IndexV2Builder(
data_type=data_type,
arrow_files=arrow_files,
cum_length=cum_lengths,
group_length=group_lengths,
indices=total_indices,
config_file=dconfig.config_file,
)
builder.build(save_path)
print(
f"Build index finished!\n\n"
f" Save path: {Path(save_path).absolute()}\n"
f" Number of indices: {len(total_indices)}\n"
f"Number of arrow files: {len(arrow_files)}\n"
)
def worker_startup(rank, world_size, dconfig, prefix, work_dir, callback=None):
# Prepare names for this worker
num = (len(dconfig.names) + world_size - 1) // world_size
arrow_names = dconfig.names[rank * num : (rank + 1) * num]
print(f"Rank {rank} has {len(arrow_names):,} names.")
# Run get indices
print(f"Start getting indices...")
indices = []
for arrow_name, repeat_times in tqdm(
arrow_names, position=rank, desc=f"#{rank}: ", leave=False
):
indices.append(
get_indices(
arrow_name, repeat_times, dconfig.filter, dconfig.repeater, callback
)
)
# Save to a temp file
temp_save_path = (
work_dir / f"data/temp_pickles/{prefix}-{rank + 1}_of_{world_size}.pkl"
)
temp_save_path.parent.mkdir(parents=True, exist_ok=True)
with temp_save_path.open("wb") as f:
pickle.dump(indices, f)
print(f"Rank {rank} finished. Write temporary data to {temp_save_path}")
return temp_save_path
def startup(
config_file,
save,
world_size=1,
work_dir=".",
callback=None,
use_cache=False,
):
work_dir = Path(work_dir)
save_path = Path(save)
if save_path.suffix != ".json":
save_path = save_path.parent / (save_path.name + ".json")
print(f"Using save_path: {save_path}")
prefix = f"{save_path.stem}"
# Parse dataset config and build the data_type list
dconfig = DatasetConfig(work_dir, config_file)
data_type = []
for k, v in dconfig.data_type.items():
data_type.extend(v)
print(f"{k}:")
for x in v:
print(f" {x}")
if dconfig.remove_md5_dup:
data_type.append("Remove md5 duplicates.")
else:
data_type.append("Keep md5 duplicates.")
# Start processing
if not use_cache:
temp_pickles = []
if world_size == 1:
print(f"\nRunning in single process mode...")
temp_pickles.append(
worker_startup(
rank=0,
world_size=1,
dconfig=dconfig,
prefix=prefix,
work_dir=work_dir,
callback=callback,
)
)
else:
print(f"\nRunning in multi-process mode (world_size={world_size})...")
p = Pool(world_size)
temp_pickles_ = []
for i in range(world_size):
temp_pickles_.append(
p.apply_async(
worker_startup,
args=(i, world_size, dconfig, prefix, work_dir, callback),
)
)
for res in temp_pickles_:
temp_pickles.append(res.get())
# close
p.close()
p.join()
else:
temp_pickles = glob(
f"{work_dir}/data/temp_pickles/{prefix}-*_of_{world_size}.pkl"
)
# Merge temp pickles and build index
merge_and_build_index(
data_type,
temp_pickles,
dconfig,
save_path,
)
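# Usage sketch (illustrative only, not part of the original file): building a base index
# from a dataset config with 8 worker processes. Paths are hypothetical.
def _example_startup():
    startup(
        config_file="configs/my_dataset.yaml",  # hypothetical config
        save="data/jsons/my_dataset.json",      # hypothetical output index
        world_size=8,
    )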
def make_multireso(
target,
config_file=None,
src=None,
base_size=None,
reso_step=None,
target_ratios=None,
align=None,
min_size=None,
md5_file=None,
):
if config_file is not None:
with Path(config_file).open() as f:
config = yaml.safe_load(f)
else:
config = {}
src = config.get("src", src)
base_size = config.get("base_size", base_size)
reso_step = config.get("reso_step", reso_step)
target_ratios = config.get("target_ratios", target_ratios)
align = config.get("align", align)
min_size = config.get("min_size", min_size)
md5_file = config.get("md5_file", md5_file)
if src is None:
raise ValueError("src must be provided in either config file or command line.")
if base_size is None:
raise ValueError("base_size must be provided.")
if reso_step is None and target_ratios is None:
raise ValueError("Either reso_step or target_ratios must be provided.")
if md5_file is not None:
with open(md5_file, "rb") as f:
md5_hw = pickle.load(f)
print(f"Md5 to height and width: {len(md5_hw):,}")
else:
md5_hw = None
build_multi_resolution_bucket(
config_file=config_file,
base_size=base_size,
reso_step=reso_step,
target_ratios=target_ratios,
align=align,
min_size=min_size,
src_index_files=src,
save_file=target,
md5_hw=md5_hw,
)
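# Usage sketch (illustrative only, not part of the original file): building a
# multi-resolution index from an existing base index. Paths are hypothetical.
def _example_make_multireso():
    make_multireso(
        target="data/jsons/my_dataset_mb_1024.json",  # hypothetical output
        src=["data/jsons/my_dataset.json"],           # hypothetical base index
        base_size=1024,
        reso_step=64,
        align=1,
        min_size=512,
    )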