Commit 727428ec authored by jerrrrry

Initial commit CI/CD
blank_issues_enabled: false
---
name: Error report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
Thanks for your error report; we appreciate it a lot.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
**Describe the bug**
A clear and concise description of what the bug is.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Did you make any modifications to the code or config? Do you understand what you modified?
3. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add additional information that may be helpful for locating the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Error traceback**
If applicable, paste the error traceback here.
```none
A placeholder for the traceback.
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Describe the feature**
**Motivation**
A clear and concise description of the motivation of the feature.
Ex1. It is inconvenient when \[....\].
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
**Related resources**
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
**Additional context**
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
---
name: General questions
about: Ask general questions to get help
title: ''
labels: ''
assignees: ''
---
---
name: Reimplementation Questions
about: Ask about questions during model reimplementation
title: ''
labels: reimplementation
assignees: ''
---
**Notice**
There are several common situations in reimplementation issues, as listed below:
1. Reimplement a model in the model zoo using the provided configs
2. Reimplement a model in the model zoo on another dataset (e.g., custom datasets)
3. Reimplement a custom model for which all the components are implemented in HunyuanDiT
4. Reimplement a custom model with new modules implemented by yourself
There are different things to do for each case, as listed below.
- For cases 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.
- For cases 2 & 4, please understand that we are not able to offer much help here, because we usually do not know the full code and users are responsible for the code they write.
- One suggestion for cases 2 & 4 is to first check whether the bug lies in the self-implemented code or in the original code. For example, first make sure that the same model runs well on supported datasets. If you still need help, please describe what you have done and what you observed in the issue, follow the steps in the following sections, and be as clear as possible so that we can better help you.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The issue has not been fixed in the latest version.
**Describe the issue**
A clear and concise description of the problem you encountered and what you have done.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. What config did you run?
```none
A placeholder for the config.
```
3. Did you make any modifications to the code or config? Do you understand what you modified?
4. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add additional information that may be helpful for locating the problem, such as
1. How you installed PyTorch \[e.g., pip, conda, source\]
2. Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Results**
If applicable, paste the related results here, e.g., what you expect and what you get.
```none
A placeholder for results comparison
```
**Issue fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
---
name: General questions
about: Ask general questions to get help
title: ''
labels: ''
assignees: ''
---
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Describe the feature**
**Motivation**
A clear and concise description of the motivation for the feature.
Ex1. It is inconvenient when \[....\].
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
**Related resources**
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
**Additional context**
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
---
name: Reimplementation questions
about: Ask questions about model reimplementation
title: ''
labels: reimplementation
assignees: ''
---
**Notice**
There are several common situations in reimplementation issues, as listed below:
1. Reimplement a model in the model zoo using the provided configs
2. Reimplement a model in the model zoo on another dataset (e.g., custom datasets)
3. Reimplement a custom model, but all the components used are implemented in HunyuanDiT
4. Reimplement a custom model with new modules implemented by yourself
There are different things to do for each case, as listed below.
- For cases 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.
- For cases 2 & 4, please understand that we are not able to offer much help here, because we usually do not know the full code and users are responsible for the code they write.
- One suggestion for cases 2 & 4 is to first check whether the bug lies in the self-implemented code or in the original code. For example, first make sure that the same model runs well on supported datasets. If you still need help, please describe what you have done and what you observed, follow the steps in the following sections, and be as clear as possible so that we can better help you.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The issue has not been fixed in the latest version.
**Describe the issue**
A clear and concise description of the problem you encountered and what you have done.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. What config did you run?
```none
A placeholder for the config.
```
3. Did you make any modifications to the code or config? Do you understand what you modified?
4. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect the necessary environment information and paste it here.
2. You may add additional information that may be helpful for locating the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Results**
If applicable, paste the related results here, e.g., what you expected and what you got.
```none
A placeholder for results comparison
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
---
name: Error report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
Thanks for your error report; we appreciate it a lot.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
**Describe the bug**
A clear and concise description of what the bug is.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Did you make any modifications to the code or config? Do you understand what you modified?
3. What dataset did you use?
**Environment**
1. Please run `python utils/collect_env.py` to collect the necessary environment information and paste it here.
2. You may add additional information that may be helpful for locating the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Error traceback**
If applicable, paste the error traceback here.
```none
A placeholder for the traceback.
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/
.idea/
.DS_Store
results/
trt/activate.sh
trt/deactivate.sh
*.onnx
ckpts/
app/output/
[submodule "comfyui-hydit"]
path = comfyui-hydit
url = https://github.com/zml-ai/comfyui-hydit
# IndexKits
[TOC]
## Introduction
Index Kits (`index_kits`) for streaming Arrow data.
* Supports creating datasets from configuration files.
* Supports creating index v2 from arrows.
* Supports the creation, reading, and usage of index files.
* Supports the creation, reading, and usage of multi-bucket, multi-resolution indexes.
`index_kits` has the following advantages:
* Optimizes memory usage for streaming reads, supporting datasets with billions of samples.
* Optimizes memory usage and reading speed of Index files.
* Supports the creation of Base/Multireso Index V2 datasets, including data filtering, repeating, and deduplication during creation.
* Supports loading various dataset types (Arrow files, Base Index V2, multiple Base Index V2, Multireso Index V2, multiple Multireso Index V2).
* Uses a unified API interface to access images, text, and other attributes for all dataset types.
* Built-in shuffle and `resize_and_crop`, which also returns the crop coordinates.
## Installation
* **Install from pre-compiled whl (recommended)**
* Install from source
```shell
cd IndexKits
pip install -e .
```
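After installation, a quick import check can confirm the package is available (a minimal sanity check; the version string is the one defined in `index_kits/__init__.py` in this commit):
```python
import index_kits

print(index_kits.__version__)  # e.g. 0.3.5
```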
## Usage
### Loading an Index V2 File
```python
from index_kits import ArrowIndexV2
index_manager = ArrowIndexV2('data.json')
# You can shuffle the dataset using a seed. If a seed is provided (i.e., the seed is not None), this shuffle will not affect any random state.
index_manager.shuffle(1234, fast=True)
for i in range(len(index_manager)):
# Get an image (requires the arrow to contain an image/binary column):
pil_image = index_manager.get_image(i)
# Get text (requires the arrow to contain a text_zh column):
text = index_manager.get_attribute(i, column='text_zh')
# Get MD5 (requires the arrow to contain an md5 column):
md5 = index_manager.get_md5(i)
# Get any attribute by specifying the column (must be contained in the arrow):
ocr_num = index_manager.get_attribute(i, column='ocr_num')
# Get multiple attributes at once
item = index_manager.get_data(i, columns=['text_zh', 'md5']) # i: in-dataset index
print(item)
# {
# 'index': 3, # in-json index
# 'in_arrow_index': 3, # in-arrow index
# 'arrow_name': '/HunYuanDiT/dataset/porcelain/00000.arrow',
# 'text_zh': 'Fortune arrives with the auspicious green porcelain tea cup',
# 'md5': '1db68f8c0d4e95f97009d65fdf7b441c'
# }
```
### Loading a Set of Arrow Files
If you have a small batch of arrow files, you can directly use `IndexV2Builder` to load these arrow files without creating an Index V2 file.
```python
from index_kits import IndexV2Builder
index_manager = IndexV2Builder(arrow_files).to_index_v2()
```
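For example, the arrow file list can be collected with `glob`; the directory below is the placeholder path used elsewhere in this document, and the `text_zh` column is assumed to exist, as in the examples above:
```python
from glob import glob

from index_kits import IndexV2Builder

# Collect a small batch of arrow files and build an in-memory index over them.
arrow_files = sorted(glob('/HunYuanDiT/dataset/porcelain/arrows/*.arrow'))
index_manager = IndexV2Builder(arrow_files).to_index_v2()
print(len(index_manager), index_manager.get_attribute(0, column='text_zh'))
```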
### Loading a Multi-Bucket (Multi-Resolution) Index V2 File
When using a multi-bucket (multi-resolution) index file, refer to the following example code, especially the definition of `SimpleDataset` and the usage of `MultiResolutionBucketIndexV2`.
```python
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from index_kits import MultiResolutionBucketIndexV2
from index_kits.sampler import BlockDistributedSampler
class SimpleDataset(Dataset):
def __init__(self, index_file, batch_size, world_size):
# When using multi-bucket (multi-resolution), batch_size and world_size need to be specified for underlying data alignment.
self.index_manager = MultiResolutionBucketIndexV2(index_file, batch_size, world_size)
self.flip_norm = T.Compose(
[
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize([0.5], [0.5]),
]
)
def shuffle(self, seed, fast=False):
self.index_manager.shuffle(seed, fast=fast)
def __len__(self):
return len(self.index_manager)
def __getitem__(self, idx):
image = self.index_manager.get_image(idx)
original_size = image.size # (w, h)
target_size = self.index_manager.get_target_size(idx) # (w, h)
image, crops_coords_left_top = self.index_manager.resize_and_crop(image, target_size)
image = self.flip_norm(image)
text = self.index_manager.get_attribute(idx, column='text_zh')
return image, text, (original_size, target_size, crops_coords_left_top)
batch_size = 8 # batch_size per GPU
world_size = 8 # total number of GPUs
rank = 0 # rank of the current process
num_workers = 4 # customizable based on actual conditions
shuffle = False # must be set to False
drop_last = True # must be set to True
# Correct batch_size and world_size must be passed in, otherwise, the multi-resolution data cannot be aligned correctly.
dataset = SimpleDataset('data_multireso.json', batch_size=batch_size, world_size=world_size)
# Must use BlockDistributedSampler to ensure the samples in a batch have the same resolution.
sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank,
shuffle=shuffle, drop_last=drop_last, batch_size=batch_size)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
num_workers=num_workers, pin_memory=True, drop_last=drop_last)
for epoch in range(10):
# Please use the shuffle method provided by the dataset, not the DataLoader's shuffle parameter.
dataset.shuffle(epoch, fast=True)
for batch in loader:
pass
```
## Fast Shuffle
When your index v2 file contains hundreds of millions of samples, using the default shuffle can be quite slow. Therefore, it’s recommended to enable `fast=True` mode:
```python
index_manager.shuffle(seed=1234, fast=True)
```
This performs a global shuffle without keeping indices from the same arrow file together. It may reduce reading speed, so weigh this cost against the model's forward time.
#!/usr/bin/env python3
import argparse
from index_kits.dataset.make_dataset_core import startup, make_multireso
from index_kits.common import show_index_info
from index_kits import __version__
def common_args(parser):
parser.add_argument('-t', '--target', type=str, required=True, help='Save path')
def get_args():
parser = argparse.ArgumentParser(description="""
IndexKits is a tool to build and manage index files for large-scale datasets.
It supports both base index and multi-resolution index.
Introduction
------------
This command line tool provides the following functionalities:
1. Show index v2 information
2. Build base index v2
3. Build multi-resolution index v2
Examples
--------
1. Show index v2 information
index_kits show /path/to/index.json
2. Build base index v2
Default usage:
index_kits base -c /path/to/config.yaml -t /path/to/index.json
Use multiple processes:
index_kits base -c /path/to/config.yaml -t /path/to/index.json -w 40
3. Build multi-resolution index v2
Build with a configuration file:
index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json
Build by specifying arguments without a configuration file:
index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
Build by specifying target-ratios:
index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json
Build with multiple source index files.
index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
""", formatter_class=argparse.RawTextHelpFormatter)
sub_parsers = parser.add_subparsers(dest='task', required=True)
# Show index message
show_parser = sub_parsers.add_parser('show', description="""
Show base/multireso index v2 information.
Example
-------
index_kits show /path/to/index.json
""", formatter_class=argparse.RawTextHelpFormatter)
show_parser.add_argument('src', type=str, help='Path to a base/multireso index file.')
show_parser.add_argument('--arrow-files', action='store_true', help='Show arrow files only.')
show_parser.add_argument('--depth', type=int, default=1,
help='Arrow file depth. Default is 1, the level of last folder in the arrow file path. '
'Set it to 0 to show the full path including `xxx/last_folder/*.arrow`.')
# Single resolution bucket
base_parser = sub_parsers.add_parser('base', description="""
Build base index v2.
Example
-------
index_kits base -c /path/to/config.yaml -t /path/to/index.json
""", formatter_class=argparse.RawTextHelpFormatter)
base_parser.add_argument('-c', '--config', type=str, required=True, help='Configuration file path')
common_args(base_parser)
base_parser.add_argument('-w', '--world-size', type=int, default=1)
base_parser.add_argument('--work-dir', type=str, default='.', help='Work directory')
    base_parser.add_argument('--use-cache', action='store_true',
                             help='Use cache to avoid reprocessing. Merge the cached pkl results directly.')
# Multi-resolution bucket
mo_parser = sub_parsers.add_parser('multireso', description="""
Build multi-resolution index v2
Example
-------
Build with a configuration file:
index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json
Build by specifying arguments without a configuration file:
index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
Build by specifying target-ratios:
index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json
Build with multiple source index files.
index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
""", formatter_class=argparse.RawTextHelpFormatter)
mo_parser.add_argument('-c', '--config', type=str, default=None,
help='Configuration file path in a yaml format. Either --config or --src must be provided.')
mo_parser.add_argument('-s', '--src', type=str, nargs='+', default=None,
help='Source index files. Either --config or --src must be provided.')
common_args(mo_parser)
    mo_parser.add_argument('--base-size', type=int, default=None, help="Base size. Typically set to 256/512/1024 according to the image size used for training.")
mo_parser.add_argument('--reso-step', type=int, default=None,
help="Resolution step. Either reso_step or target_ratios must be provided.")
mo_parser.add_argument('--target-ratios', type=str, nargs='+', default=None,
help="Target ratios. Either reso_step or target_ratios must be provided.")
mo_parser.add_argument('--md5-file', type=str, default=None,
help='You can provide an md5 to height and width file to accelerate the process. '
'It is a pickle file that contains a dict, which maps md5 to (height, width) tuple.')
    mo_parser.add_argument('--align', type=int, default=16, help="Used when --target-ratios is provided. The multiple to align the target resolution height and width to.")
mo_parser.add_argument('--min-size', type=int, default=0,
help="Minimum size. Images smaller than this size will be ignored.")
# Common
parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = get_args()
if args.task == 'show':
show_index_info(args.src,
args.arrow_files,
args.depth,
)
elif args.task == 'base':
startup(args.config,
args.target,
args.world_size,
args.work_dir,
use_cache=args.use_cache,
)
elif args.task == 'multireso':
make_multireso(args.target,
args.config,
args.src,
args.base_size,
args.reso_step,
args.target_ratios,
args.align,
args.min_size,
args.md5_file,
)
# Use the IndexKits library to create datasets.
[TOC]
## Introduction
The IndexKits library offers a command-line tool `idk` for creating datasets and viewing their statistics. You can view the usage instructions with `idk -h`. Here, creating a dataset means building an Index V2 format dataset from a series of arrow files.
## 1. Creating a Base Index V2 Dataset
### 1.1 Create a Base Index V2 dataset using `idk`
When creating a Base Index V2 dataset, you need to specify the path to a configuration file using the `-c` parameter and a save path using the `-t` parameter.
```shell
idk base -c base_config.yaml -t base_dataset.json
```
### 1.2 Basic Configuration
Next, let’s discuss how to write the configuration file. The configuration file is in `yaml` format, and below is a basic example:
Filename: `base_config.yaml`
```yaml
source:
- /HunYuanDiT/dataset/porcelain/arrows/00000.arrow
```
| Field Name | Type | Description |
|:---------:|:---:|:----------:|
| source | Optional | Arrow List |
We provide an example that includes all features and fields in [full_config.yaml](./docs/full_config.yaml).
### 1.3 Filtering Criteria
`idk` offers two types of filtering capabilities during the dataset creation process: (1) based on columns in Arrow, and (2) based on MD5 files.
To enable filtering criteria, add the `filter` field in the configuration file.
#### 1.3.1 Column Filtering
To enable column filtering, add the `column` field under the `filter` section.
Multiple column filtering criteria can be applied simultaneously, with the intersection of multiple conditions being taken.
For example, to select data where both the height and width are greater than or equal to 512,
defaulting to 1024 when the height or width is invalid:
```yaml
filter:
column:
- name: height
type: int
action: ge
target: 512
default: 1024
- name: width
type: int
action: ge
target: 512
default: 1024
```
This filtering condition is equivalent to `table['height'].to_int(default=1024) >= 512 && table['width'].to_int(default=1024) >= 512`.
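Conceptually, the mask could be computed over a pyarrow table as in the sketch below. This is only an illustration of the semantics, not the library's internal implementation:
```python
import pandas as pd
import pyarrow as pa


def height_width_mask(table: pa.Table) -> pd.Series:
    # Invalid or missing values fall back to the configured default of 1024.
    heights = pd.to_numeric(table["height"].to_pandas(), errors="coerce").fillna(1024)
    widths = pd.to_numeric(table["width"].to_pandas(), errors="coerce").fillna(1024)
    return (heights >= 512) & (widths >= 512)
```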
Each filtering criterion includes the following fields:
| Field Name | Type | Description | Value Range |
|:------------------:|:---:|:---------------:|:----------------------:|
| name | Required | Column name in Arrow | Column in Arrow |
| type | Required | Type of Elements in the Column | `int`, `float` or `str` |
| action | Required | Filtering action (comparison operator) | See the table below for possible values |
| target | Required | Filtering target (the value compared against) | Numeric or String |
| default | Required | Default value when the element is invalid | Numeric or String |
| arrow_file_keyword | Optional | Keywords in the Arrow file path | - |
Below are the specific meanings of `action` and its possible values in different circumstances:
| Action | Description | Type | Action | Description | Type |
|:-------------:|:----------------------------------------------------:|:---------------------:|:------------:|:--------------------------------------------:|:---------------------:|
| eq | equal, `==` | `int`, `float`, `str` | ne | not equal, `!=` | `int`, `float`, `str` |
| gt | greater than, `>` | `int`, `float` | lt | less than, `<` | `int`, `float` |
| ge | greater than or equal, `>=` | `int`, `float` | le | less than or equal, `<=` | `int`, `float` |
| len_eq | str length equal, `str.len()==` | `str` | len_ne | str length not equal, `str.len()!=` | `str` |
| len_gt | str length greater than, `str.len()>` | `str` | len_lt | str length less than, `str.len()<` | `str` |
| len_ge | str length greater than or equal, `str.len()>=` | `str` | len_le | str length less than or equal, `str.len()<=` | `str` |
| contains | str contains, `str.contains(target)` | `str` | not_contains | str does not contain, `str.not_contains(target)` | `str` |
| in | str in, `str.in(target)` | `str` | not_in | str not in, `str.not_in(target)` | `str` |
| lower_last_in | last char of lowered str in, `str.lower()[-1].in(target)` | `str` | | | |
#### 1.3.2 MD5 Filtering
Add an `md5` field under the `filter` section to initiate MD5 filtering. Multiple MD5 filtering criteria can be applied simultaneously, with the intersection of multiple conditions being taken.
For example:
* `badcase.txt` is a list of MD5s, aiming to filter out the entries listed in these lists.
* `badcase.json` is a dictionary, where the key is the MD5 and the value is a text-related `tag`. The goal is to filter out specific `tags`.
```yaml
filter:
md5:
- name: badcase1
path:
- badcase1.txt
type: list
action: in
is_valid: false
- name: badcase2
path: badcase2.json
type: dict
action: eq
target: 'Specified tag'
is_valid: false
```
Each filtering criterion includes the following fields:
| Field Name | Type | Description | Value Range |
|:------------------:|:---:|:-----------------------------------------------------------------:|:------------------------------------------------------------------------------------:|
| name | Required | The name of the filtering criterion, which can be customized for ease of statistics. | - |
| path | Required | The path to the filtering file, which can be a single path or multiple paths provided in a list format. Supports `.txt`, `.json`, `.pkl` formats. | - |
| type | Required | The type of records in the filtering file. | `list` or `dict` |
| action | Required | The filtering action. | For `list`: `in`, `not_in`; For `dict`: `eq`, `ne`, `gt`, `lt`, `ge`, `le`|
| target | Optional | The target value for the action. | Required when type is `dict` |
| is_valid | Required | Whether a hit on action+target is considered valid or invalid. | `true` or `false` |
| arrow_file_keyword | Optional | Keywords in the Arrow file path. | - |
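For reference, filter files like `badcase1.txt` (list type) and `badcase2.json` (dict type) from the example above could be produced as follows; the MD5 value is just the sample value shown earlier in this documentation:
```python
import json
from pathlib import Path

# list type: one MD5 per line
Path("badcase1.txt").write_text("1db68f8c0d4e95f97009d65fdf7b441c\n")

# dict type: MD5 -> tag
Path("badcase2.json").write_text(
    json.dumps({"1db68f8c0d4e95f97009d65fdf7b441c": "Specified tag"}, indent=2)
)
```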
### 1.4 Advanced Filtering
`idk` also supports some more advanced filtering functions.
#### 1.4.1 Filtering Criteria Applied to Part of Arrow Files
Using the `arrow_file_keyword` parameter, filtering criteria can be applied only to part of the Arrow files.
For example:
* The filtering criterion `height>=0` applies only to arrows whose path includes `human`.
* The filtering criterion “keep samples in goodcase.txt” applies only to arrows whose path includes `human`.
```yaml
filter:
column:
- name: height
type: int
action: ge
target: 512
default: 1024
arrow_file_keyword:
- human
md5:
- name: goodcase
path: goodcase.txt
type: list
action: in
is_valid: true
arrow_file_keyword:
- human
```
#### 1.4.2 The “or” Logic in Filtering Criteria
By default, filtering criteria follow an “and” logic. If you want two or more filtering criteria to follow an “or” logic, you should use the `logical_or` field. The column filtering criteria listed under this field will be combined using an “or” logic.
```yaml
filter:
column:
- logical_or:
- name: md5
type: str
action: lower_last_in
target: '02468ace'
default: ''
- name: text_zh
type: str
action: contains
target: 'turtle|rabbit|duck|swan|peacock|hedgehog|wolf|fox|seal|crow|deer'
default: ''
```
Special Note: The `logical_or` field is applicable only to the `column` filtering criteria within `filter`.
#### 1.4.3 Excluding Certain Arrows from the Source
While wildcards can be used to fetch multiple arrows at once, there might be instances where we want to exclude some of them. This can be achieved through the `exclude` field. Keywords listed under `exclude`, if found in the path of the current group of arrows, will result in those arrows being excluded.
```yaml
source:
- /HunYuanDiT/dataset/porcelain/arrows/*.arrow:
exclude:
- arrow1
- arrow2
```
### 1.5 Repeating Samples
`idk` offers the capability to repeat either all samples or specific samples during dataset creation. There are three types of repeaters:
* Directly repeating the source
* Based on keywords in the Arrow file name (enable repeat conditions by adding the `repeater` field in the configuration file)
* Based on an MD5 file (enable repeat conditions by adding the `repeater` field in the configuration file)
**Special Note:** The above three conditions can be used simultaneously. If a sample meets multiple repeat conditions, the highest number of repeats will be taken.
#### 1.5.1 Repeating the Source
In the source, for the Arrow(s) you want to repeat (which can include wildcards like `*`, `?`), add `repeat: n` to mark the number of times to repeat.
```yaml
source:
- /HunYuanDiT/dataset/porcelain/arrows/*.arrow:
repeat: 10
```
**Special Note: Add a colon at the end of the Arrow path**
#### 1.5.2 Arrow Keywords
Add an `arrow_file_keyword` field under the `repeater` section.
```yaml
repeater:
arrow_file_keyword:
- repeat: 8
keyword:
- Lolita anime style
- Minimalist style
- repeat: 5
keyword:
- Magical Barbie style
- Disney style
```
Each repeat condition includes two fields:
| Field Name | Type | Description | Value Range |
|:-------:|:---:|:---------------:|:----:|
| repeat | Required | The number of times to repeat | Number |
| keyword | Required | Keywords in the Arrow file path | - |
#### 1.5.3 MD5 File
Add an `md5` field under the `repeater` section.
```yaml
repeater:
md5:
- name: goodcase1
path: /HunYuanDiT/dataset/porcelain/md5_repeat_1.json
type: dict
plus: 3
- name: goodcase2
path: /HunYuanDiT/dataset/porcelain/md5_repeat_2.json
type: list
repeat: 6
```
Each repeat condition includes the following fields:
| Field Name | Type | Description | Value Range |
|:------:|:---:|:----------------------------------------------:|:---------------------------:|
| name | Required | Custom name for the repeat condition | - |
| path | Required | Path to the file containing MD5s | `.txt`, `.json`, or `.pkl` formats |
| type | Required | Type of the MD5 file | `list` or `dict` |
| repeat | Optional | Number of times to repeat; overrides `plus` | Integer |
| plus | Optional | Adds this value to the value obtained from the MD5 file as the number of repeats | Integer |
### 1.6 Deduplication
To deduplicate, add `remove_md5_dup` and set it to `true`. For example:
```yaml
remove_md5_dup: true
```
**Special Note:** Deduplication is performed after repeat conditions, so using them together will make repeat ineffective.
### 1.7 Creating a Base Index V2 Dataset with Python Code
```python
from index_kits import IndexV2Builder
builder = IndexV2Builder(arrow_files)
builder.save('data_v2.json')
```
## 2. Creating a Multireso Index V2 Dataset
### 2.1 Creating a Multireso Index V2 Dataset with `idk`
A multi-resolution dataset can be created through a configuration file:
```yaml
src:
- /HunYuanDiT/dataset/porcelain/jsons/a.json
- /HunYuanDiT/dataset/porcelain/jsons/b.json
- /HunYuanDiT/dataset/porcelain/jsons/c.json
base_size: 512
reso_step: 32
min_size: 512
```
The fields are as follows:
| Field Name | Type | Description | Value Range |
|:-------------:|:---:|:------------------------------------------------------:|:-------------------------------:|
| src | Required | Path(s) to the Base Index V2 file, can be single or multiple | - |
| base_size | Required | The base resolution (n, n) from which to create multiple resolutions | Recommended values: 256/512/1024 |
| reso_step | Optional | The step size for traversing multiple resolutions. Choose either this or target_ratios | Recommended values: 16/32/64 |
| target_ratios | Optional | Target aspect ratios, a list. Choose either this or reso_step | Recommended values: 1:1, 4:3, 3:4, 16:9, 9:16 |
| align | Optional | When using target_ratios, the multiple to align the target resolution to | Recommended value: 16 (2x patchify, 8x VAE) |
| min_size | Optional | The minimum resolution filter for samples when creating from Base to Multireso Index V2 | Recommended values: 256/512/1024 |
| md5_file | Optional | A pre-calculated dictionary of image sizes in pkl format, key is MD5, value is (h, w) | - |
### 2.2 Creating a Multireso Index V2 Dataset with Python Code
First, import the necessary function
```python
from index_kits import build_multi_resolution_bucket
md5_hw = None
# If many arrows in your Index V2 lack height and width, you can pre-calculate them and pass through the md5_hw parameter.
# md5_hw = {'c67be1d8f30fd0edcff6ac99b703879f': (720, 1280), ...}
# with open('md5_hw.pkl', 'rb') as f:
# md5_hw = pickle.load(f)
index_v2_file_path = 'data_v2.json'
index_v2_save_path = 'data_multireso.json'
# Method 1: Given base_size and reso_step, automatically calculate all resolution buckets.
build_multi_resolution_bucket(base_size=1024,
reso_step=64,
min_size=1024,
src_index_files=index_v2_file_path,
save_file=index_v2_save_path,
md5_hw=md5_hw)
# Method 2: Given a series of target aspect ratios, automatically calculate all resolution buckets.
build_multi_resolution_bucket(base_size=1024,
target_ratios=["1:1", "3:4", "4:3", "16:9", "9:16"],
align=16,
min_size=1024,
src_index_files=index_v2_file_path,
save_file=index_v2_save_path,
md5_hw=md5_hw)
```
**Note:** If both reso_step and target_ratios are provided, target_ratios will be prioritized.
source:
- /HunyuanDiT/dataset/task1/*.arrow
- /HunyuanDiT/dataset/task2/*.arrow:
repeat: 2
remove_md5_dup: false
filter:
column:
- name: height
type: int
action: ge
target: 512
default: 1024
- name: width
type: int
action: ge
target: 512
default: 1024
md5:
- name: Badcase
path: /path/to/bad_case_md5.txt
type: list
action: in
is_valid: false
repeater:
md5:
- name: Hardcase Repeat
path: /path/to/hard_case_md5.json
type: dict
plus: 3
- name: Goodcase Repeat
path: /path/to/good_case_md5.txt
type: list
repeat: 6
from .bucket import (
MultiIndexV2,
MultiResolutionBucketIndexV2,
MultiMultiResolutionBucketIndexV2,
build_multi_resolution_bucket,
)
from .bucket import Resolution, ResolutionGroup
from .indexer import IndexV2Builder, ArrowIndexV2
from .common import load_index, show_index_info
__version__ = "0.3.5"
import json
import numpy as np
from pathlib import Path
from functools import partial
from .indexer import ArrowIndexV2
from .bucket import (
ResolutionGroup,
MultiIndexV2,
MultiResolutionBucketIndexV2,
MultiMultiResolutionBucketIndexV2,
IndexV2Builder,
)
def load_index(
src,
multireso=False,
batch_size=1,
world_size=1,
sample_strategy="uniform",
probability=None,
shadow_file_fn=None,
seed=None,
):
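    """Load an index manager of the appropriate type.

    `.arrow` sources are wrapped with IndexV2Builder; `.json` sources map to
    ArrowIndexV2 / MultiIndexV2 (base) or MultiResolutionBucketIndexV2 /
    MultiMultiResolutionBucketIndexV2 (multireso), depending on how many
    files are given and on the `multireso` flag.
    """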
if isinstance(src, str):
src = [src]
if src[0].endswith(".arrow"):
if multireso:
raise ValueError(
"Arrow file does not support multiresolution. Please make base index V2 first and then"
"build multiresolution index."
)
idx = IndexV2Builder(src).to_index_v2()
elif src[0].endswith(".json"):
if multireso:
if len(src) == 1:
idx = MultiResolutionBucketIndexV2(
src[0],
batch_size=batch_size,
world_size=world_size,
shadow_file_fn=shadow_file_fn,
)
else:
idx = MultiMultiResolutionBucketIndexV2(
src,
batch_size=batch_size,
world_size=world_size,
sample_strategy=sample_strategy,
probability=probability,
shadow_file_fn=shadow_file_fn,
seed=seed,
)
else:
if len(src) == 1:
idx = ArrowIndexV2(
src[0],
shadow_file_fn=shadow_file_fn,
)
else:
idx = MultiIndexV2(
src,
sample_strategy=sample_strategy,
probability=probability,
shadow_file_fn=shadow_file_fn,
seed=seed,
)
else:
raise ValueError(f"Unknown file type: {src[0]}")
return idx
def get_attribute(data, attr_list):
ret_data = {}
for attr in attr_list:
ret_data[attr] = data.get(attr, None)
if ret_data[attr] is None:
raise ValueError(f"Missing {attr} in {data}")
return ret_data
def get_optional_attribute(data, attr_list):
ret_data = {}
for attr in attr_list:
ret_data[attr] = data.get(attr, None)
return ret_data
def detect_index_type(data):
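    # Multireso indices store `group_length` as a dict keyed by bucket resolution (HxW);
    # base indices store it as a plain list.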
if isinstance(data["group_length"], dict):
return "multireso"
else:
return "base"
def show_index_info(src, only_arrow_files=False, depth=1):
"""
Show base/multireso index information.
"""
if not Path(src).exists():
raise ValueError(f"{src} does not exist.")
print(f"Loading index file {src} ...")
with open(src, "r") as f:
src_data = json.load(f)
print(f"Loaded.")
data = get_attribute(
src_data,
[
"data_type",
"indices_file",
"arrow_files",
"cum_length",
"group_length",
"indices",
"example_indices",
],
)
opt_data = get_optional_attribute(src_data, ["config_file"])
# Format arrow_files examples
arrow_files = data["arrow_files"]
if only_arrow_files:
existed = set()
arrow_files_output_list = []
for arrow_file in arrow_files:
if depth == 0:
if arrow_file not in existed:
arrow_files_output_list.append(arrow_file)
existed.add(arrow_file)
elif depth > 0:
parts = Path(arrow_file).parts
if depth >= len(parts):
continue
else:
arrow_file_part = "/".join(parts[:-depth])
if arrow_file_part not in existed:
arrow_files_output_list.append(arrow_file_part)
existed.add(arrow_file_part)
else:
raise ValueError(
f"Depth {depth} has exceeded the limit of arrow file {arrow_file}."
)
arrow_files_repr = "\n".join(arrow_files_output_list)
print(arrow_files_repr)
return None
return_space = "\n" + " " * 21
if len(arrow_files) <= 4:
arrow_files_repr = return_space.join([arrow_file for arrow_file in arrow_files])
else:
arrow_files_repr = return_space.join(
[_ for _ in arrow_files[:2]] + ["..."] + [_ for _ in arrow_files[-2:]]
)
arrow_files_length = len(arrow_files)
# Format data_type
data_type = data["data_type"]
if isinstance(data_type, str):
data_type = [data_type]
data_type_common = []
src_files = []
found_src_files = False
for data_type_item in data_type:
if not found_src_files and data_type_item.strip() != "src_files=":
data_type_common.append(data_type_item.strip())
continue
found_src_files = True
if data_type_item.endswith(".json"):
src_files.append(data_type_item.strip())
else:
data_type_common.append(data_type_item.strip())
data_type_part2_with_ids = []
max_id_len = len(str(len(src_files)))
for sid, data_type_item in enumerate(src_files, start=1):
data_type_part2_with_ids.append(
f"{str(sid).rjust(max_id_len)}. {data_type_item}"
)
data_type = data_type_common + data_type_part2_with_ids
data_repr = return_space.join(data_type)
# Format cum_length examples
cum_length = data["cum_length"]
if len(cum_length) <= 8:
cum_length_repr = ", ".join([str(i) for i in cum_length])
else:
cum_length_repr = ", ".join(
[str(i) for i in cum_length[:4]]
+ ["..."]
+ [str(i) for i in cum_length[-4:]]
)
cum_length_length = len(cum_length)
if detect_index_type(data) == "base":
# Format group_length examples
group_length = data["group_length"]
if len(group_length) <= 8:
group_length_repr = ", ".join([str(i) for i in group_length])
else:
group_length_repr = ", ".join(
[str(i) for i in group_length[:4]]
+ ["..."]
+ [str(i) for i in group_length[-4:]]
)
group_length_length = len(group_length)
# Format indices examples
indices = data["indices"]
if len(indices) == 0 and data["indices_file"] != "":
indices_file = Path(src).parent / data["indices_file"]
if Path(indices_file).exists():
print(f"Loading indices from {indices_file} ...")
indices = np.load(indices_file)["x"]
print(f"Loaded.")
else:
raise ValueError(
f"This Index file contains an extra file {indices_file} which is missed."
)
if len(indices) <= 8:
indices_repr = ", ".join([str(i) for i in indices])
else:
indices_repr = ", ".join(
[str(i) for i in indices[:4]] + ["..."] + [str(i) for i in indices[-4:]]
)
# Calculate indices total length
indices_length = len(indices)
print_str = f"""File: {Path(src).absolute()}
ArrowIndexV2(
\033[4mdata_type:\033[0m {data_repr}"""
# Process optional data
if opt_data["config_file"] is not None:
print_str += f"""
\033[4mconfig_file:\033[0m {opt_data['config_file']}"""
# Add common data
print_str += f"""
\033[4mindices_file:\033[0m {data['indices_file']}
\033[4marrow_files: Count = {arrow_files_length:,}\033[0m
Examples: {arrow_files_repr}
\033[4mcum_length: Count = {cum_length_length:,}\033[0m
Examples: {cum_length_repr}
\033[4mgroup_length: Count = {group_length_length:,}\033[0m
Examples: {group_length_repr}
\033[4mindices: Count = {indices_length:,}\033[0m
Examples: {indices_repr}"""
else:
group_length = data["group_length"]
indices_file = Path(src).parent / data["indices_file"]
assert Path(indices_file).exists(), f"indices_file {indices_file} not found"
print(f"Loading indices from {indices_file} ...")
indices_data = np.load(indices_file)
print(f"Loaded.")
indices_length = sum([len(indices) for key, indices in indices_data.items()])
keys = [k for k in group_length.keys() if len(indices_data[k]) > 0]
resolutions = ResolutionGroup.from_list_of_hxw(keys)
resolutions.attr = [
f"{len(indices):>,d}" for k, indices in indices_data.items()
]
resolutions.prefix_space = 25
print_str = f"""File: {Path(src).absolute()}
MultiResolutionBucketIndexV2(
\033[4mdata_type:\033[0m {data_repr}"""
# Process optional data
if opt_data["config_file"] is not None:
print_str += f"""
\033[4mconfig_file:\033[0m {opt_data['config_file']}"""
# Process config files of base index files
config_files = []
for src_file in src_files:
src_file = Path(src_file)
if src_file.exists():
with src_file.open() as f:
base_data = json.load(f)
if "config_file" in base_data:
config_files.append(base_data["config_file"])
else:
config_files.append("Unknown")
else:
config_files.append("Missing the src file")
if config_files:
config_file_str = return_space.join(
[
f"{str(sid).rjust(max_id_len)}. {config_file}"
for sid, config_file in enumerate(config_files, start=1)
]
)
print_str += f"""
\033[4mbase config files:\033[0m {config_file_str}"""
# Add common data
print_str += f"""
\033[4mindices_file:\033[0m {data['indices_file']}
\033[4marrow_files: Count = {arrow_files_length:,}\033[0m
Examples: {arrow_files_repr}
\033[4mcum_length: Count = {cum_length_length:,}\033[0m
Examples: {cum_length_repr}
\033[4mindices: Count = {indices_length:,}\033[0m
\033[4mbuckets: Count = {len(keys)}\033[0m
{resolutions}"""
print(print_str + "\n)\n")
import json
import pickle
from collections import defaultdict
from glob import glob
from multiprocessing import Pool
from pathlib import Path
import pandas as pd
import yaml
import numpy as np
import pyarrow as pa
from tqdm import tqdm
from index_kits.indexer import IndexV2Builder
from index_kits.bucket import build_multi_resolution_bucket
from index_kits.dataset.config_parse import DatasetConfig
def get_table(arrow_file):
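    # Memory-map the arrow file and read all record batches into a pyarrow Table.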
return pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
def get_indices(arrow_file, repeat_times, filter_fn, repeat_fn, callback=None):
"""
Get valid indices from a single arrow_file.
Parameters
----------
arrow_file: str
repeat_times: int
Repeat remain indices multiple times.
filter_fn
callback
Returns
-------
"""
try:
table = pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
except Exception as e:
print(arrow_file, e)
raise e
length = len(table)
if len(table) == 0:
print(f"Warning: Empty table: {arrow_file}")
indices = []
stats = {}
else:
# Apply filter_fn if available
if filter_fn is not None:
mask, stats, md5s = filter_fn(arrow_file, table)
else:
mask = pd.Series([True] * length)
stats = {}
md5s = None
# Apply callback function if available
if callback is not None:
mask, stats = callback(arrow_file, table, mask, stats, md5s)
# Get indices
if mask is not None:
indices = np.where(mask)[0].tolist()
else:
indices = list(range(length))
# Apply indices repeat
if repeat_fn is not None:
indices, repeat_stats = repeat_fn(
arrow_file, table, indices, repeat_times, md5s
)
stats.update(repeat_stats)
return arrow_file, length, indices, stats
def load_md5_files(files, name=None):
if isinstance(files, str):
files = [files]
md5s = set()
for file in files:
md5s.update(Path(file).read_text().splitlines())
print(f" {name} md5s: {len(md5s):,}")
return md5s
def load_md52cls_files(files, name=None):
if isinstance(files, str):
files = [files]
md52cls = {}
for file in files:
with Path(file).open() as f:
md52cls.update(json.load(f))
print(f" {name} md52cls: {len(md52cls):,}")
return md52cls
def merge_and_build_index(data_type, src, dconfig, save_path):
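    """Merge the per-worker temp pickle results (optionally dropping md5 duplicates)
    and build the final base Index V2 file at `save_path`."""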
if isinstance(src, str):
files = list(sorted(glob(src)))
else:
files = list(sorted(src))
print(f"Found {len(files):,} temp pickle files.")
for fname in files:
print(f" {fname}")
arrow_files = []
table_lengths = []
indices_list = []
bad_stats_total = defaultdict(int)
total_indices = 0
total_processed_length = 0
for file_name in tqdm(files):
with Path(file_name).open("rb") as f:
data = pickle.load(f)
for arrow_file, table_length, indices, *args in tqdm(data, leave=False):
arrow_files.append(arrow_file)
table_lengths.append(table_length)
total_processed_length += table_length
indices_list.append(indices)
total_indices += len(indices)
if len(args) > 0 and args[0]:
bad_stats = args[0]
for k, v in bad_stats.items():
bad_stats_total[k] += v
if len(bad_stats_total):
stats_save_dir = Path(save_path).parent
stats_save_dir.mkdir(parents=True, exist_ok=True)
stats_save_path = stats_save_dir / (Path(save_path).stem + "_stats.txt")
stats_save_path.write_text(
"\n".join([f"{k:>50s} {v}" for k, v in bad_stats_total.items()]) + "\n"
)
print(f"Save stats to {stats_save_path}")
print(f"Arrow files: {len(arrow_files):,}")
print(f"Processed indices: {total_processed_length:,}")
print(f"Valid indices: {total_indices:,}")
cum_length = 0
total_indices = []
cum_lengths = []
group_lengths = []
existed = set()
print(f"Accumulating indices...")
pbar = tqdm(
zip(arrow_files, table_lengths, indices_list),
total=len(arrow_files),
mininterval=1,
)
_count = 0
for arrow_file, table_length, indices in pbar:
if len(indices) > 0 and dconfig.remove_md5_dup:
new_indices = []
table = get_table(arrow_file)
if "md5" not in table.column_names:
raise ValueError(
f"Column 'md5' not found in {arrow_file}. "
f"When `remove_md5_dup: true` is set, md5 column is required."
)
md5s = table["md5"].to_pandas()
for i in indices:
md5 = md5s[i]
if md5 in existed:
continue
existed.add(md5)
new_indices.append(i)
indices = new_indices
total_indices.extend([int(i + cum_length) for i in indices])
cum_length += table_length
cum_lengths.append(cum_length)
group_lengths.append(len(indices))
_count += 1
if _count % 100 == 0:
pbar.set_description(f"Indices: {len(total_indices):,}")
builder = IndexV2Builder(
data_type=data_type,
arrow_files=arrow_files,
cum_length=cum_lengths,
group_length=group_lengths,
indices=total_indices,
config_file=dconfig.config_file,
)
builder.build(save_path)
print(
f"Build index finished!\n\n"
f" Save path: {Path(save_path).absolute()}\n"
f" Number of indices: {len(total_indices)}\n"
f"Number of arrow files: {len(arrow_files)}\n"
)
def worker_startup(rank, world_size, dconfig, prefix, work_dir, callback=None):
# Prepare names for this worker
num = (len(dconfig.names) + world_size - 1) // world_size
arrow_names = dconfig.names[rank * num : (rank + 1) * num]
print(f"Rank {rank} has {len(arrow_names):,} names.")
# Run get indices
print(f"Start getting indices...")
indices = []
for arrow_name, repeat_times in tqdm(
arrow_names, position=rank, desc=f"#{rank}: ", leave=False
):
indices.append(
get_indices(
arrow_name, repeat_times, dconfig.filter, dconfig.repeater, callback
)
)
# Save to a temp file
temp_save_path = (
work_dir / f"data/temp_pickles/{prefix}-{rank + 1}_of_{world_size}.pkl"
)
temp_save_path.parent.mkdir(parents=True, exist_ok=True)
with temp_save_path.open("wb") as f:
pickle.dump(indices, f)
print(f"Rank {rank} finished. Write temporary data to {temp_save_path}")
return temp_save_path
def startup(
config_file,
save,
world_size=1,
work_dir=".",
callback=None,
use_cache=False,
):
work_dir = Path(work_dir)
save_path = Path(save)
if save_path.suffix != ".json":
save_path = save_path.parent / (save_path.name + ".json")
print(f"Using save_path: {save_path}")
prefix = f"{save_path.stem}"
# Parse dataset config and build the data_type list
dconfig = DatasetConfig(work_dir, config_file)
data_type = []
for k, v in dconfig.data_type.items():
data_type.extend(v)
print(f"{k}:")
for x in v:
print(f" {x}")
if dconfig.remove_md5_dup:
data_type.append("Remove md5 duplicates.")
else:
data_type.append("Keep md5 duplicates.")
# Start processing
if not use_cache:
temp_pickles = []
if world_size == 1:
print(f"\nRunning in single process mode...")
temp_pickles.append(
worker_startup(
rank=0,
world_size=1,
dconfig=dconfig,
prefix=prefix,
work_dir=work_dir,
callback=callback,
)
)
else:
print(f"\nRunning in multi-process mode (world_size={world_size})...")
p = Pool(world_size)
temp_pickles_ = []
for i in range(world_size):
temp_pickles_.append(
p.apply_async(
worker_startup,
args=(i, world_size, dconfig, prefix, work_dir, callback),
)
)
for res in temp_pickles_:
temp_pickles.append(res.get())
# close
p.close()
p.join()
else:
temp_pickles = glob(
f"{work_dir}/data/temp_pickles/{prefix}-*_of_{world_size}.pkl"
)
# Merge temp pickles and build index
merge_and_build_index(
data_type,
temp_pickles,
dconfig,
save_path,
)
def make_multireso(
target,
config_file=None,
src=None,
base_size=None,
reso_step=None,
target_ratios=None,
align=None,
min_size=None,
md5_file=None,
):
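    """Build a multi-resolution Index V2 file. Values found in the yaml config file
    take precedence over the corresponding arguments passed on the command line."""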
if config_file is not None:
with Path(config_file).open() as f:
config = yaml.safe_load(f)
else:
config = {}
src = config.get("src", src)
base_size = config.get("base_size", base_size)
reso_step = config.get("reso_step", reso_step)
target_ratios = config.get("target_ratios", target_ratios)
align = config.get("align", align)
min_size = config.get("min_size", min_size)
md5_file = config.get("md5_file", md5_file)
if src is None:
raise ValueError("src must be provided in either config file or command line.")
if base_size is None:
raise ValueError("base_size must be provided.")
if reso_step is None and target_ratios is None:
raise ValueError("Either reso_step or target_ratios must be provided.")
if md5_file is not None:
with open(md5_file, "rb") as f:
md5_hw = pickle.load(f)
print(f"Md5 to height and width: {len(md5_hw):,}")
else:
md5_hw = None
build_multi_resolution_bucket(
config_file=config_file,
base_size=base_size,
reso_step=reso_step,
target_ratios=target_ratios,
align=align,
min_size=min_size,
src_index_files=src,
save_file=target,
md5_hw=md5_hw,
)