import dotenv
dotenv.load_dotenv(override=True)
import argparse
import os
from typing import List, Tuple
from PIL import Image, ImageOps
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from accelerate import Accelerator
from diffusers.hooks import apply_group_offloading
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="OmniGen2 image generation script.")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--transformer_path",
type=str,
default=None,
help="Path to transformer checkpoint.",
)
parser.add_argument(
"--transformer_lora_path",
type=str,
default=None,
help="Path to transformer LoRA checkpoint.",
)
parser.add_argument(
"--scheduler",
type=str,
default="euler",
choices=["euler", "dpmsolver++"],
help="Scheduler to use.",
)
parser.add_argument(
"--num_inference_step",
type=int,
default=50,
help="Number of inference steps."
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for generation."
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Output image height."
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Output image width."
)
parser.add_argument(
"--max_input_image_pixels",
type=int,
default=1048576,
help="Maximum number of pixels for each input image."
)
parser.add_argument(
"--dtype",
type=str,
default='bf16',
choices=['fp32', 'fp16', 'bf16'],
help="Data type for model weights."
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=5.0,
help="Text guidance scale."
)
parser.add_argument(
"--image_guidance_scale",
type=float,
default=2.0,
help="Image guidance scale."
)
parser.add_argument(
"--cfg_range_start",
type=float,
default=0.0,
help="Start of the CFG range."
)
parser.add_argument(
"--cfg_range_end",
type=float,
default=1.0,
help="End of the CFG range."
)
parser.add_argument(
"--instruction",
type=str,
default="A dog running in the park",
help="Text prompt for generation."
)
parser.add_argument(
"--negative_prompt",
type=str,
default="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
help="Negative prompt for generation."
)
parser.add_argument(
"--input_image_path",
type=str,
nargs='+',
default=None,
help="Path(s) to input image(s)."
)
parser.add_argument(
"--output_image_path",
type=str,
default="output.png",
help="Path to save output image."
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images to generate per prompt."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--enable_group_offload",
action="store_true",
help="Enable group offload."
)
parser.add_argument(
"--enable_teacache",
action="store_true",
help="Enable teacache to speed up inference."
)
parser.add_argument(
"--teacache_rel_l1_thresh",
type=float,
default=0.05,
help="Relative L1 threshold for teacache."
)
parser.add_argument(
"--enable_taylorseer",
action="store_true",
help="Enable TaylorSeer Caching."
)
return parser.parse_args()
def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
pipeline = OmniGen2Pipeline.from_pretrained(
args.model_path,
torch_dtype=weight_dtype,
trust_remote_code=True,
)
if args.transformer_path:
print(f"Transformer weights loaded from {args.transformer_path}")
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.transformer_path,
torch_dtype=weight_dtype,
)
else:
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
if args.transformer_lora_path:
print(f"LoRA weights loaded from {args.transformer_lora_path}")
pipeline.load_lora_weights(args.transformer_lora_path)
if args.enable_teacache and args.enable_taylorseer:
print("WARNING: enable_teacache and enable_taylorseer are mutually exclusive. enable_teacache will be ignored.")
if args.enable_taylorseer:
pipeline.enable_taylorseer = True
elif args.enable_teacache:
pipeline.transformer.enable_teacache = True
pipeline.transformer.teacache_rel_l1_thresh = args.teacache_rel_l1_thresh
if args.scheduler == "dpmsolver++":
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
pipeline.scheduler = scheduler
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
elif args.enable_group_offload:
apply_group_offloading(pipeline.transformer, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.mllm, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
def preprocess(input_image_path: List[str] = []) -> List[Image.Image]:
"""Preprocess the input images."""
# Process input images
input_images = None
if input_image_path:
input_images = []
if isinstance(input_image_path, str):
input_image_path = [input_image_path]
if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
input_images = [Image.open(os.path.join(input_image_path[0], f)).convert("RGB")
for f in os.listdir(input_image_path[0])]
else:
input_images = [Image.open(path).convert("RGB") for path in input_image_path]
input_images = [ImageOps.exif_transpose(img) for img in input_images]
return input_images
def run(args: argparse.Namespace,
accelerator: Accelerator,
pipeline: OmniGen2Pipeline,
instruction: str,
negative_prompt: str,
input_images: List[Image.Image]) -> Image.Image:
"""Run the image generation pipeline with the given parameters."""
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
results = pipeline(
prompt=instruction,
input_images=input_images,
width=args.width,
height=args.height,
num_inference_steps=args.num_inference_step,
max_sequence_length=1024,
text_guidance_scale=args.text_guidance_scale,
image_guidance_scale=args.image_guidance_scale,
cfg_range=(args.cfg_range_start, args.cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
)
return results
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
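# Images arrive in [-1, 1] (callers pass to_tensor(img) * 2 - 1), so map back to [0, 1] before pasting onto the canvas.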
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)
def main(args: argparse.Namespace, root_dir: str) -> None:
"""Main function to run the image generation process."""
# Initialize accelerator
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != 'fp32' else 'no')
# Set weight dtype
weight_dtype = torch.float32
if args.dtype == 'fp16':
weight_dtype = torch.float16
elif args.dtype == 'bf16':
weight_dtype = torch.bfloat16
# Load pipeline and process inputs
pipeline = load_pipeline(args, accelerator, weight_dtype)
input_images = preprocess(args.input_image_path)
# Generate and save image
results = run(args, accelerator, pipeline, args.instruction, args.negative_prompt, input_images)
os.makedirs(os.path.dirname(args.output_image_path) or ".", exist_ok=True)  # dirname may be empty for a bare filename
if len(results.images) > 1:
for i, image in enumerate(results.images):
image_name, ext = os.path.splitext(args.output_image_path)
image.save(f"{image_name}_{i}{ext}")
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
output_image.save(args.output_image_path)
print(f"Image saved to {args.output_image_path}")
if __name__ == "__main__":
root_dir = os.path.abspath(os.path.join(__file__, os.path.pardir))
args = parse_args()
main(args, root_dir)
import dotenv
dotenv.load_dotenv(override=True)
import argparse
import os
from typing import List, Tuple
from PIL import Image
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from accelerate import Accelerator
from diffusers.hooks import apply_group_offloading
from omnigen2.pipelines.omnigen2.pipeline_omnigen2_chat import OmniGen2ChatPipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="OmniGen2 image generation script.")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--scheduler",
type=str,
default="euler",
choices=["euler", "dpmsolver++"],
help="Scheduler to use.",
)
parser.add_argument(
"--num_inference_step",
type=int,
default=50,
help="Number of inference steps."
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for generation."
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Output image height."
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Output image width."
)
parser.add_argument(
"--max_input_image_pixels",
type=int,
default=1048576,
help="Maximum number of pixels for each input image."
)
parser.add_argument(
"--dtype",
type=str,
default='bf16',
choices=['fp32', 'fp16', 'bf16'],
help="Data type for model weights."
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=5.0,
help="Text guidance scale."
)
parser.add_argument(
"--image_guidance_scale",
type=float,
default=2.0,
help="Image guidance scale."
)
parser.add_argument(
"--cfg_range_start",
type=float,
default=0.0,
help="Start of the CFG range."
)
parser.add_argument(
"--cfg_range_end",
type=float,
default=1.0,
help="End of the CFG range."
)
parser.add_argument(
"--instruction",
type=str,
default="A dog running in the park",
help="Text prompt for generation."
)
parser.add_argument(
"--negative_prompt",
type=str,
default="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
help="Negative prompt for generation."
)
parser.add_argument(
"--input_image_path",
type=str,
nargs='+',
default=None,
help="Path(s) to input image(s)."
)
parser.add_argument(
"--output_image_path",
type=str,
default="output.png",
help="Path to save output image."
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images to generate per prompt."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_group_offload",
action="store_true",
help="Enable group offload."
)
return parser.parse_args()
def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2ChatPipeline:
pipeline = OmniGen2ChatPipeline.from_pretrained(
args.model_path,
torch_dtype=weight_dtype,
trust_remote_code=True,
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
if args.scheduler == "dpmsolver":
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
pipeline.scheduler = scheduler
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
elif args.enable_group_offload:
apply_group_offloading(pipeline.transformer, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.mllm, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
def preprocess(input_image_path: List[str] = []) -> List[Image.Image]:
"""Preprocess the input images."""
# Process input images
input_images = None
if input_image_path:
input_images = []
if isinstance(input_image_path, str):
input_image_path = [input_image_path]
if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
input_images = [Image.open(os.path.join(input_image_path[0], f)).convert("RGB")
for f in os.listdir(input_image_path[0])]
else:
input_images = [Image.open(path).convert("RGB") for path in input_image_path]
return input_images
def run(args: argparse.Namespace,
accelerator: Accelerator,
pipeline: OmniGen2ChatPipeline,
instruction: str,
negative_prompt: str,
input_images: List[Image.Image]) -> Image.Image:
"""Run the image generation pipeline with the given parameters."""
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
results = pipeline(
prompt=instruction,
input_images=input_images,
width=args.width,
height=args.height,
num_inference_steps=args.num_inference_step,
max_sequence_length=1024,
text_guidance_scale=args.text_guidance_scale,
image_guidance_scale=args.image_guidance_scale,
cfg_range=(args.cfg_range_start, args.cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
)
return results
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)
def main(args: argparse.Namespace, root_dir: str) -> None:
"""Main function to run the image generation process."""
# Initialize accelerator
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != 'fp32' else 'no')
# Set weight dtype
weight_dtype = torch.float32
if args.dtype == 'fp16':
weight_dtype = torch.float16
elif args.dtype == 'bf16':
weight_dtype = torch.bfloat16
# Load pipeline and process inputs
pipeline = load_pipeline(args, accelerator, weight_dtype)
input_images = preprocess(args.input_image_path)
# Generate and save image
results = run(args, accelerator, pipeline, args.instruction, args.negative_prompt, input_images)
if results.images is not None:
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
os.makedirs(os.path.dirname(args.output_image_path) or ".", exist_ok=True)  # dirname may be empty for a bare filename
output_image.save(args.output_image_path)
print(f"Image saved to {args.output_image_path}")
print(f"Text: {results.text}")
if __name__ == "__main__":
root_dir = os.path.abspath(os.path.join(__file__, os.path.pardir))
args = parse_args()
main(args, root_dir)
# Model code
modelCode=1670
# Model name
modelName=OmniGen2_pytorch
# Model description
modelDescription=Introduces a reflection mechanism, tops multimodal generation benchmarks, and unlocks Doraemon's "Anywhere Door" for AI image generation with one click.
# Application scenarios
appScenario=Inference,Multimodal,Manufacturing,Broadcast media,Finance,Energy,Healthcare,Smart home,Education
# Framework type
frameType=pytorch
# OmniContext
As part of OmniGen2, we introduce a new benchmark for in-context generation, **OmniContext**, which aims to provide a more comprehensive evaluation of models' in-context generation abilities. It incorporates a diverse set of input images and instructions, and utilizes GPT-4.1 for interpretable, metric-driven assessment.
<p align="center">
<img src="../assets/omnicontext_overview.png" width="95%">
<br>
<em>Overview of OmniContext benchmark.</em>
</p>
<p align="center">
<img src="../assets/omnicontext_evaluation.png" width="95%">
<br>
<em>An illustrative evaluation case in the OmniContext benchmark.</em>
</p>
The evaluation of the OmniContext benchmark includes the following steps:
## Step 1: Environment Setup
```bash
# 1. Activate Python environment
conda activate omnigen2
# 2. Install dependencies
pip install -U datasets megfile
```
## Step 2: Generate Images
Note: we fix the output resolution at 1024 × 1024 to keep the settings consistent across different models.
You may generate results with OmniGen2 or other models; please ensure that the output image directory structure and file format match the layout specified below.
```
results/
├── {method_name}/
│ └── fullset/
│ └── {task_type}/
│ ├── key1.png
│ ├── key2.png
│ └── ...
```
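For reference, the sketch below shows how such an output path can be assembled and a generated image saved. `save_result` is a hypothetical helper written for illustration; the bundled `omnicontext.inference` script builds the same `{result_dir}/{model_name}/fullset/{task_type}/{key}.png` paths inline.

```python
import os
from PIL import Image


def save_result(result_dir: str, method_name: str, task_type: str, key: str, image: Image.Image) -> str:
    """Save one generated image under results/{method_name}/fullset/{task_type}/{key}.png."""
    sub_dir = os.path.join(result_dir, method_name, "fullset", task_type)
    os.makedirs(sub_dir, exist_ok=True)
    output_path = os.path.join(sub_dir, f"{key}.png")
    image.save(output_path)
    return output_path
```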
To use OmniGen2, you can run the following script to generate images:
```bash
cd OmniGen2
accelerate launch --num_processes=8 -m omnicontext.inference \
--model_path "OmniGen2/OmniGen2" \
--model_name "OmniGen2" \
--test_data "OmniGen2/OmniContext" \
--result_dir "omnicontext/results" \
--num_inference_step 50 \
--height 1024 \
--width 1024 \
--text_guidance_scale 5.0 \
--image_guidance_scale 2.0 \
--num_images_per_prompt 1 \
--disable_align_res # align_res matches the output resolution to the input image for image-editing tasks; disable it for in-context generation tasks.
```
## Step 3: Evaluation
1. We use GPT-4.1 to evaluate the quality of the generated images. Please make sure to set up your API key before running the script.
```bash
cd OmniGen2
openai_key="<Your-API-Key>"
python -m omnicontext.test_omnicontext_score \
--test_data "OmniGen2/OmniContext" \
--result_dir "omnicontext/results" \
--model_name "OmniGen2" \
--openai_key ${openai_key} \
--max_workers 100
```
2. Next, calculate the final score:
```bash
python -m omnicontext.calculate_statistics \
--save_path "omnicontext/results" \
--model_name "OmniGen2" \
--backbone gpt4dot1
```
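For orientation, each `score.jsonl` written in Step 3 contains one JSON object per line; `calculate_statistics` reads the `PF_score`, `SC_score`, and `instruction_language` fields and reports the geometric mean of the two scores as the overall score. An illustrative entry (all values are placeholders) is shown below as a Python dict:

```python
# Illustrative score.jsonl entry; field names follow omnicontext.test_omnicontext_score.
{
    "key": "example_key",            # sample identifier from the OmniContext dataset
    "task_type": "...",              # task subset the sample belongs to
    "instruction": "...",
    "instruction_language": "en",
    "output_image_path": "...",
    "PF_score": 7,                   # prompt-following score, 0-10
    "PF_score_reason": "...",
    "SC_score": 8,                   # subject-consistency score, 0-10
    "SC_score_reason": "...",
}
```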
## Acknowledgements
The code structure of this benchmark is inspired by [Step1X-Edit](https://github.com/stepfun-ai/Step1X-Edit).
Special thanks to the original project for their valuable contribution.
import megfile
import os
import pandas as pd
from collections import defaultdict
import sys
import numpy as np
import math
import json
import glob
def analyze_scores(json_lines, language):
group_prompt_following_scores = {}
group_subject_consistency_scores = {}
group_overall_scores = {}
for task_type in json_lines.keys():
prompt_following_scores = []
subject_consistency_scores = []
overall_scores = []
for json_line in json_lines[task_type]:
if json_line['instruction_language'] != language:
continue
prompt_following_score = json_line['PF_score']
subject_consistency_score = json_line['SC_score']
overall_score = math.sqrt(prompt_following_score * subject_consistency_score)
prompt_following_scores.append(prompt_following_score)
subject_consistency_scores.append(subject_consistency_score)
overall_scores.append(overall_score)
group_prompt_following_scores[task_type] = np.mean(prompt_following_scores)
group_subject_consistency_scores[task_type] = np.mean(subject_consistency_scores)
group_overall_scores[task_type] = np.mean(overall_scores)
return group_prompt_following_scores, group_subject_consistency_scores, group_overall_scores
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--save_path", type=str, default="/results/")
parser.add_argument("--backbone", type=str, default="gpt4dot1")
parser.add_argument("--model_name", type=str, default="OmniGen2")
parser.add_argument("--language", type=str, default="en")
args = parser.parse_args()
result_json_files = glob.glob(os.path.join(args.save_path, args.model_name, args.backbone, "**/*.jsonl"))
print(f"{result_json_files=}")
print(f"{len(result_json_files)=}")
result_json_lines = defaultdict(list)
for result_json_file in result_json_files:
with open(result_json_file, 'r') as f:
for line in f:
data = json.loads(line)
task_type = os.path.basename(os.path.dirname(result_json_file))
result_json_lines[task_type].append(data)
group_prompt_following_scores, group_subject_consistency_scores, group_overall_scores = analyze_scores(result_json_lines, language=args.language)
for task_type in group_prompt_following_scores.keys():
print(f"{task_type}: {group_prompt_following_scores[task_type]:.3f}, {group_subject_consistency_scores[task_type]:.3f}, {group_overall_scores[task_type]:.3f}")
print(f"Average: {np.mean(list(group_prompt_following_scores.values())):.3f}, {np.mean(list(group_subject_consistency_scores.values())):.3f}, {np.mean(list(group_overall_scores.values())):.3f}")
import dotenv
dotenv.load_dotenv(override=True)
import argparse
import os
import datasets
from tqdm import tqdm
from typing import List, Tuple
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from accelerate import Accelerator
from diffusers.hooks import apply_group_offloading
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="OmniGen2 image generation script.")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Model name for output directory.",
)
parser.add_argument(
"--scheduler",
type=str,
default="euler",
choices=["euler", "dpmsolver"],
help="Scheduler to use.",
)
parser.add_argument(
"--num_inference_step",
type=int,
default=50,
help="Number of inference steps."
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for generation."
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Output image height."
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Output image width."
)
parser.add_argument(
"--max_input_image_pixels",
type=int,
default=1048576,
help="Maximum number of pixels for each input image."
)
parser.add_argument(
"--dtype",
type=str,
default='bf16',
choices=['fp32', 'fp16', 'bf16'],
help="Data type for model weights."
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=5.0,
help="Text guidance scale."
)
parser.add_argument(
"--image_guidance_scale",
type=float,
default=2.0,
help="Image guidance scale."
)
parser.add_argument(
"--cfg_range_start",
type=float,
default=0.0,
help="Start of the CFG range."
)
parser.add_argument(
"--cfg_range_end",
type=float,
default=1.0,
help="End of the CFG range."
)
parser.add_argument(
"--negative_prompt",
type=str,
default="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
help="Negative prompt for generation."
)
parser.add_argument(
"--test_data",
type=str,
default=None,
help="Path to test data."
)
parser.add_argument(
"--result_dir",
type=str,
default="results",
help="Path to save the generated images."
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images to generate per prompt."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--enable_group_offload",
action="store_true",
help="Enable group offload."
)
parser.add_argument(
"--disable_align_res",
action="store_true",
help="Align resolution to the input image resolution."
)
return parser.parse_args()
class Collator:
def __call__(self, features):
return features
def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
from transformers import CLIPProcessor
pipeline = OmniGen2Pipeline.from_pretrained(
args.model_path,
processor=CLIPProcessor.from_pretrained(
args.model_path,
subfolder="processor",
use_fast=True
),
torch_dtype=weight_dtype,
trust_remote_code=True,
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
if args.scheduler == "dpmsolver":
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
pipeline.scheduler = scheduler
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
elif args.enable_group_offload:
apply_group_offloading(pipeline.transformer, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.mllm, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
def preprocess(input_image_path: List[str] = []) -> List[Image.Image]:
"""Preprocess the input images."""
# Process input images
input_images = None
if input_image_path:
input_images = []
if isinstance(input_image_path, str):
input_image_path = [input_image_path]
if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
input_images = [Image.open(os.path.join(input_image_path[0], f)).convert("RGB")
for f in os.listdir(input_image_path[0])]
else:
input_images = [Image.open(path).convert("RGB") for path in input_image_path]
input_images = [ImageOps.exif_transpose(img) for img in input_images]
return input_images
def run(args: argparse.Namespace,
accelerator: Accelerator,
pipeline: OmniGen2Pipeline,
instruction: str,
negative_prompt: str,
input_images: List[Image.Image]) -> Image.Image:
"""Run the image generation pipeline with the given parameters."""
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
results = pipeline(
prompt=instruction,
input_images=input_images,
width=args.width,
height=args.height,
align_res=not args.disable_align_res,
num_inference_steps=args.num_inference_step,
max_sequence_length=1024,
text_guidance_scale=args.text_guidance_scale,
image_guidance_scale=args.image_guidance_scale,
cfg_range=(args.cfg_range_start, args.cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
)
return results
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)
def main(args: argparse.Namespace, root_dir: str) -> None:
"""Main function to run the image generation process."""
# Initialize accelerator
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != 'fp32' else 'no')
test_dataset = datasets.load_dataset(args.test_data, split="train")
print('test_dataset', test_dataset)
loader = DataLoader(
test_dataset,
collate_fn=Collator(),
batch_size=1,
shuffle=True,
# shuffle=False,
pin_memory=False,
drop_last=False,
)
loader = accelerator.prepare(loader)
# Set weight dtype
weight_dtype = torch.float32
if args.dtype == 'fp16':
weight_dtype = torch.float16
elif args.dtype == 'bf16':
weight_dtype = torch.bfloat16
# Load pipeline and process inputs
pipeline = load_pipeline(args, accelerator, weight_dtype)
with tqdm(
total=len(loader),
desc="Generating images...",
unit="image",
disable=not accelerator.is_main_process,
) as pbar:
for i, batched_data in tqdm(enumerate(loader), total=len(loader), disable=accelerator.process_index!=0):
for data in batched_data:
key = data['key']
task_type = data['task_type']
instruction = data['instruction']
input_images = data['input_images']
input_images = [ImageOps.exif_transpose(img) for img in input_images]
sub_dir = os.path.join(args.result_dir, args.model_name, "fullset", task_type)
os.makedirs(sub_dir, exist_ok=True)
output_image_path = os.path.join(sub_dir, f"{key}.png")
# Skip samples that were already generated in a previous run
if os.path.exists(output_image_path):
continue
# Generate and save image
results = run(args, accelerator, pipeline, instruction, args.negative_prompt, input_images)
if len(results.images) > 1:
for i, image in enumerate(results.images):
image_name, ext = os.path.splitext(output_image_path)
image.save(f"{image_name}_{i}{ext}")
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
output_image.save(output_image_path)
pbar.update(1)
if __name__ == "__main__":
root_dir = os.path.abspath(os.path.join(__file__, os.path.pardir))
args = parse_args()
main(args, root_dir)
import os
from typing import Union, List, Optional
import json
import regex as re
import ast
import random
def fix_json(input_str):
# Add double quotes around keys using regex
fixed_str = re.sub(r'(\w+):', r'"\1":', input_str)
# Add double quotes around string values if necessary and wrap int/float values in []
def format_value(match):
key, value, comma = match.groups()
value = value.strip()
# Check if value is an integer or float
if re.match(r'^-?\d+(\.\d+)?$', value):
value = f'[{value}]'
# Check if value is a boolean or null
elif re.match(r'^(true|false|null)$', value, re.IGNORECASE):
pass # leave as is
else:
# Add quotes around string values
value = f'"{value}"'
return f'{key}: {value}{comma}'
fixed_str = re.sub(r'(".*?"):(.*?)(,|})', format_value, fixed_str)
return fixed_str
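# Illustrative example (not part of the original code):
#   fix_json('{score: 7, reasoning: good match}')
#   -> '{"score": [7], "reasoning": "good match"}'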
def read_file_to_string(file_path):
"""
Reads the contents of a text file and returns it as a string.
:param file_path: The path to the text file.
:return: A string containing the contents of the file.
"""
try:
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
except FileNotFoundError:
print(f"The file {file_path} was not found.")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
def read_files_to_string(file_paths):
"""
Reads the contents of multiple text files and returns them as a single string,
with each file's contents separated by a newline.
:param file_paths: A list of paths to text files.
:return: A string containing the concatenated contents of the files.
"""
all_contents = [] # List to hold the contents of each file
for file_path in file_paths:
try:
with open(file_path, 'r', encoding='utf-8') as file:
all_contents.append(file.read())
except FileNotFoundError:
print(f"The file {file_path} was not found.")
except Exception as e:
print(f"An error occurred while reading {file_path}: {e}")
# Join all the contents with a newline character
return "\n".join(all_contents)
def get_file_path(filename: Union[str, os.PathLike], search_from: Union[str, os.PathLike] = "."):
"""
Search for a file across a directory and return its absolute path.
Args:
filename (Union[str, os.PathLike]): The name of the file to search for.
search_from (Union[str, os.PathLike], optional): The directory from which to start the search. Defaults to ".".
Returns:
str: Absolute path to the found file.
Raises:
FileNotFoundError: If the file is not found.
"""
for root, dirs, files in os.walk(search_from):
for name in files:
if name == filename:
return os.path.abspath(os.path.join(root, name))
raise FileNotFoundError(f"{filename} not found.")
#+=========================================================================================
def verify(s, target_sequence):
# Count the occurrences of the target sequence
count = s.count(target_sequence)
# Check if the target sequence appears exactly twice
return count == 2
def is_int_between_0_and_10(s):
try:
num = int(s)
return 0 <= num <= 10
except ValueError:
return False
def is_str_a_list_of_ints_0_to_10(s):
try:
# Attempt to parse the string as a Python literal (list, dict, etc.)
parsed = ast.literal_eval(s)
# Check if the parsed object is a list
if not isinstance(parsed, list):
return False
# Check if all elements are integers and between 0 to 10
return all(isinstance(item, int) and 0 <= item <= 10 for item in parsed)
except (ValueError, SyntaxError):
# If parsing fails or any other error occurs
return False
def is_str_valid_score_format_brackets(s):
try:
# Removing brackets and splitting the string by commas
content = s.strip("[]").split(',')
length = len(content)
# Parsing each element and checking the format and range
scores = {}
for item in content:
key, value = item.split(':')
key = key.strip()
value = int(value.strip())
# Check if the key starts with 'score' and the value is in the correct range
if not key.startswith("score") or not 0 <= value <= 10:
return False
scores[key] = value
fetch_words = [f"score{i+1}" for i in range(length)]
# Check if at least 'score1' and 'score2' are present
return all(key in scores for key in fetch_words)
except (ValueError, SyntaxError):
# If any parsing error occurs
return False
#+=========================================================================================
def mllm_output_to_dict(input_string, give_up_parsing=False):
"""
Args:
input_string (str): the raw output of the MLLM model, to be parsed into a dict.
"""
# Catch for gpt4v rate_limit_exceeded error
if input_string == "rate_limit_exceeded":
return "rate_limit_exceeded"
# find the json manually
# some MLLMs tend not to output the delimiters, but they do output the json contents
# so we will find the json content manually
start_index = input_string.find('{')
end_index = input_string.rfind('}') + 1
if start_index == -1 or end_index == 0:
# json not found
# some mllm tends to output only a list of scores like [6, 0],
# this time we will just get the scores and ignore the reasoning (other part of the json)
start_index = input_string.find('[')
end_index = input_string.rfind(']') + 1
if is_int_between_0_and_10(input_string): # if output is simply a number
score = int(input_string)
json_content = {'score': score, "reasoning": "System: output is simply a number"}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
else:
raise Exception(f"Failed to find the json content in the string: {input_string=}.")
# Check if we found two delimiters
if start_index != -1 and end_index != -1 and start_index != end_index:
# Extract the JSON string
json_str = input_string[start_index:end_index].strip()
json_str = json_str.replace("\n", "")
# Parse the JSON string into a dictionary
try:
new_data = json.loads(json_str)
except Exception as e:
print(f"Error: {e}")
print("Now fixing: ", json_str)
new_data = json.loads(fix_json(json_str))
return new_data
return new_data
else:
raise Exception(f"The required delimiters were not found correctly in the string: {input_string=}.")
def write_entry_to_json_file(input_string, uid, prompt_input, vision_input, output_file_name, give_up_parsing=False):
"""
Args:
input_string (str): actually the output of the mllm model to be parsed
uid (str): The unique identifier for the each item in the test data
prompt_input (str): The prompt input for the entry. text prompt.
vision_input (str): The vision input for the entry. image links.
output_file_name (str): The name of the output file.
"""
# Catch for gpt4v rate_limit_exceeded error
if input_string == "rate_limit_exceeded":
return "rate_limit_exceeded"
# Define the delimiters
delimiter = '||V^=^V||'
if input_string.count(delimiter) == 2:
if not verify(input_string, delimiter):
print("The required delimiters were not found correctly in the string.")
return False
# Extract the content between the delimiters
start_index = input_string.find(delimiter) + len(delimiter)
end_index = input_string.rfind(delimiter)
else:
# find the json manually
# some MLLMs tend not to output the delimiters, but they do output the json contents
# so we will find the json content manually
start_index = input_string.find('{')
end_index = input_string.rfind('}') + 1
if start_index == -1 or end_index == 0:
# json not found
# some mllm tends to output only a list of scores like [6, 0],
# this time we will just get the scores and ignore the reasoning (other part of the json)
start_index = input_string.find('[')
end_index = input_string.rfind(']') + 1
if give_up_parsing: # if we want to give up parsing
guessed_value = random.randint(0, 10)
print(f"Failed to find the json content in the string. Guess a value : {guessed_value}.")
json_content = {'score': [guessed_value], "reasoning": f"guess_if_cannot_parse | {input_string}"}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
elif re.match(r'^\[\d+, ?\d+\]$', input_string[start_index:end_index]):
scores = json.loads(input_string[start_index:end_index])
json_content = {'score': scores, "reasoning": None}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
elif is_int_between_0_and_10(input_string): # if output is simply a number
scores = [int(input_string)]
json_content = {'score': scores, "reasoning": None}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
else:
print("Failed to find the json content in the string.")
return False
# Check if we found two delimiters
if start_index != -1 and end_index != -1 and start_index != end_index:
# Extract the JSON string
json_str = input_string[start_index:end_index].strip()
json_str = json_str.replace("\n", "")
try:
# Parse the JSON string into a dictionary
new_data = json.loads(json_str)
# Ensure the directory exists
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
# Initialize or load existing data
if os.path.exists(output_file_name):
with open(output_file_name, 'r') as json_file:
data = json.load(json_file)
else:
data = {}
# If the additional key is already in the data, add or update notes
if uid in data:
data[uid].update(new_data) # Update with new data
if prompt_input: # If there are new notes, update or add them
data[uid]['prompt_input'] = prompt_input
if vision_input: # If there are new notes, update or add them
data[uid]['vision_input'] = vision_input
else:
# If it's a new key, add the entry to the dictionary
data[uid] = new_data
if prompt_input:
data[uid]['prompt_input'] = prompt_input
if vision_input:
data[uid]['vision_input'] = vision_input
# Write the updated data to the file
with open(output_file_name, 'w') as json_file:
json.dump(data, json_file, indent=4)
print(f"Data was successfully updated in {output_file_name}")
return True
except json.JSONDecodeError as e:
print(f"An error occurred while parsing the JSON content: {e}")
return False
else:
print("The required delimiters were not found correctly in the string.")
return False
def check_key_in_json(file_path, key):
try:
with open(file_path, 'r') as json_file:
data = json.load(json_file)
# Check if the key exists at the top level of the JSON structure
if key in data:
return True
else:
return False
except FileNotFoundError:
print(f"The file {file_path} was not found.")
except json.JSONDecodeError as e:
print(f"Error reading {file_path}: {e}")
except Exception as e:
print(f"An error occurred with {file_path}: {e}")
return False
from .prompt_generator import PromptGenerator
from .openai_util import ask_gpt4o
from .json_util import mllm_output_to_dict
import random
import json
import time
class OmniContextScore:
def __init__(self, openai_url: str, openai_key: str) -> None:
self.openai_url = openai_url
self.openai_key = openai_key
self.prompt_generator = PromptGenerator()
def evaluate(self, input_image_paths, instruction, with_scene=False):
results_dict = {}
max_tries = 3
PF_results, SC_results = None, None
PF_scores, SC_scores = None, None
for try_idx in range(max_tries):
try:
PF_prompt = self.prompt_generator(instruction, task_type="prompt_following")
SC_prompt = self.prompt_generator(instruction, task_type="subject_consistency", with_scene=with_scene)
PF_results = ask_gpt4o(input_image_paths, PF_prompt, self.openai_url, self.openai_key)
SC_results = ask_gpt4o(input_image_paths, SC_prompt, self.openai_url, self.openai_key)
PF_scores = mllm_output_to_dict(PF_results)
SC_scores = mllm_output_to_dict(SC_results)
if PF_scores == "rate_limit_exceeded" or SC_scores == "rate_limit_exceeded":
raise Exception("rate_limit_exceeded")
except Exception as e:
backoff_time = 2 ** try_idx # Exponential backoff: 1, 2, 4 seconds
print(f"{e}, {instruction=}, Attempt {try_idx+1} failed, retrying after {backoff_time} seconds...")
time.sleep(backoff_time)
if PF_scores is None:
guessed_value = random.randint(0, 10)
print(f"Failed to find the json content in the string for {instruction}. Guess a value : {guessed_value=}.", flush=True)
PF_scores = {'score': guessed_value, "reasoning": f"guess_if_cannot_parse | {PF_results}"}
if SC_scores is None:
guessed_value = random.randint(0, 10)
print(f"Failed to find the json content in the string for {instruction}. Guess a value : {guessed_value=}.", flush=True)
SC_scores = {'score': guessed_value, "reasoning": f"guess_if_cannot_parse | {SC_results}"}
results_dict["PF_scores"] = PF_scores
results_dict["SC_scores"] = SC_scores
return results_dict
import requests
import base64
from io import BytesIO
from PIL import Image, ImageOps
from typing import Union, Optional, Tuple, List
import os
def encode_pil_image(pil_image):
# Create an in-memory binary stream
image_stream = BytesIO()
# Save the PIL image to the binary stream in JPEG format (you can change the format if needed)
pil_image.save(image_stream, format='JPEG')
# Get the binary data from the stream and encode it as base64
image_data = image_stream.getvalue()
base64_image = base64.b64encode(image_data).decode('utf-8')
return base64_image
def load_image(image: Union[str, Image.Image], format: str = "RGB", size: Optional[Tuple] = None) -> Image.Image:
"""
Load an image from a given path or URL and convert it to a PIL Image.
Args:
image (Union[str, Image.Image]): The image path, URL, or a PIL Image object to be loaded.
format (str, optional): Desired color format of the resulting image. Defaults to "RGB".
size (Optional[Tuple], optional): Desired size for resizing the image. Defaults to None.
Returns:
Image.Image: A PIL Image in the specified format and size.
Raises:
ValueError: If the provided image format is not recognized.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = ImageOps.exif_transpose(image)
image = image.convert(format)
if size is not None:
image = image.resize(size, Image.LANCZOS)
return image
def prepare_prompt(image_links: List = [], text_prompt: str = ""):
prompt_content = []
text_dict = {"type": "text", "text": text_prompt}
prompt_content.append(text_dict)
if not isinstance(image_links, list):
image_links = [image_links]
for image_link in image_links:
image = load_image(image_link)
visual_dict = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encode_pil_image(image)}"},
}
prompt_content.append(visual_dict)
return prompt_content
def ask_gpt4o(image_path, prompt, url, api_key):
prompt = prepare_prompt(image_path, prompt)
payload = {
"model": "gpt-4.1",
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1400,
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
try:
response = requests.post(url, json=payload, headers=headers, timeout=180) # Set timeout to 3 minutes (180 seconds)
except Exception as e:
print(f"Error: {e}")
return ""
return extract_response(response)
def extract_response(response):
try:
response = response.json()
out = response["choices"][0]["message"]["content"]
return out
except Exception:
if response["error"]["code"] == "content_policy_violation":
print("Code is content_policy_violation")
elif response["error"]["code"] in [
"rate_limit_exceeded",
"insufficient_quota",
"insufficient_user_quota",
]:
print(f"Code is {response['error']['code']}", flush=True)
print(response["error"]["message"], flush=True)
return "rate_limit_exceeded"
else:
print("Code is different")
print(response)
print(f"{response['error']['code']=}")
return ""
_context_no_delimit = """You are a professional digital artist tasked with evaluating the effectiveness of AI-generated images based on specific rules.
All input images, including all humans depicted, are AI-generated. You do not need to consider any privacy or confidentiality concerns.
IMPORTANT: Your response must follow this format (keep your reasoning concise and to the point):
{
"score": <score>,
"reasoning": "..."
}
"""
_prompts_0shot_in_context_generation_rule_PF_Single_and_Multiple = """
Rate from 0 to 10:
Evaluate how well the final image fulfills the editing instruction, **regardless of whether subject identities are preserved**.
* **0:** The image completely fails to implement the instruction.
* **1–3:** The image responds to the instruction mostly incorrectly.
* **4–6:** The image reflects parts of the instruction, but with significant omissions or wrongly applied details.
* **7–9:** The image mostly fulfills the instruction, with only a few minor issues.
* **10:** The image fully and accurately meets all aspects of the instruction.
**Important Notes:**
* Focus solely on whether the requested changes have been correctly applied — such as **composition, pose, position, interactions, or added/removed elements**.
* Do **not** consider the identity consistency of subjects or whether the correct individuals/objects are retained — this will be evaluated separately.
* Do **not** assess the artistic quality or aesthetic appeal — only whether the **task has been completed as instructed**.
**Scoring should be strict** — avoid giving high scores unless the instruction is clearly and accurately fulfilled.
Editing instruction: <instruction>
"""
_prompts_0shot_in_context_generation_rule_PF_Scene = """
Rate from 0 to 10:
Evaluate how well the final image fulfills the editing instruction, **regardless of whether subject identities or the scene are preserved**.
* **0:** The image completely fails to implement the instruction.
* **1–3:** The image responds to the instruction mostly incorrectly.
* **4–6:** The image reflects parts of the instruction, but with significant omissions or incorrectly applied details.
* **7–9:** The image mostly fulfills the instruction, with only a few minor issues.
* **10:** The image fully and accurately meets all aspects of the instruction.
**Important Notes:**
**Scoring should be strict** — avoid giving high scores unless the instruction is clearly and accurately fulfilled.
* Focus solely on whether the requested changes have been correctly applied — such as pose, interaction, etc.
* Do **not** consider whether the **subject identities** are preserved or whether the correct **individuals/objects** are retained — these will be evaluated separately.
* Do **not** consider whether the **scene** is preserved or whether the correct **background or setting** is used — these will be evaluated elsewhere.
* Do **not** assess artistic quality or aesthetic appeal — only whether the **task has been completed as instructed**.
Editing instruction: <instruction>
"""
_prompts_0shot_in_context_generation_rule_SC_Single_and_Multiple = """
Rate from 0 to 10:
Evaluate whether the identities of all subjects in the final image match those of the individuals specified in the original images, as described in the instruction.
**Scoring Criteria:**
* **0:** The subject identities in the image are *completely inconsistent* with those in the reference images.
* **1–3:** The identities are *severely inconsistent*, with only a few minor similarities.
* **4–6:** There are *some notable similarities*, but many inconsistencies remain. This represents a *moderate* level of identity match.
* **7–9:** The identities are *mostly consistent*, with only minor mismatches.
* **10:** The subject identities in the final image are *perfectly consistent* with those in the original images.
**Pay special attention to:**
* Whether **facial and head features** match, including the appearance and placement of eyes, nose, mouth, cheekbones, wrinkles, chin, makeup, hairstyle, hair color, and overall facial structure and head shape.
* Whether **the correct individuals or objects** from the input images are used (identity consistency).
* **Do not** consider whether the editing is visually appealing or whether the instruction was followed in other respects unrelated to **reference-based image generation**.
* Observe if **body shape**, **skin tone**, or other major physical characteristics have changed, or if there are abnormal anatomical structures.
* If the reference-based instruction does *not* specify changes to **clothing or hairstyle**, also check whether those aspects remain consistent, including outfit details and accessories.
**Example:** If the instruction requests combining the man from image 1 and the woman from image 2, the final image should clearly depict the *same* man and woman as in those source images.
**Important:**
* Every time there is a difference, deduct one point.
* Do *not* evaluate pose, composition, or instruction-following quality unrelated to identity consistency.
* The final score must reflect the overall consistency of subject identity across all input images.
* **Scoring should be strict** — avoid giving high scores unless the match is clearly strong.
Editing instruction: <instruction>
"""
_prompts_0shot_in_context_generation_rule_SC_Scene = """
Rate from 0 to 10:
Evaluate whether the identities of all subjects and the scene background in the final image match those of the individuals specified in the original images, as described in the instruction.
**Scoring Criteria:**
* **0:** The subject identities and scene background in the image are *completely inconsistent* with those in the reference images.
* **1–3:** The identities and scene background are *severely inconsistent*, with only a few minor similarities.
* **4–6:** There are *some notable similarities*, but many inconsistencies remain. This represents a *moderate* level of identity match.
* **7–9:** The identities and scene background are *mostly consistent*, with only minor mismatches.
* **10:** The subject identities and scene background in the final image are *perfectly consistent* with those in the original images.
**Pay special attention to:**
* Whether **facial and head features** match, including the appearance and placement of eyes, nose, mouth, cheekbones, wrinkles, chin, makeup, hairstyle, hair color, and overall facial structure and head shape.
* Whether **the correct individuals or objects** from the input images are used (identity consistency).
* **Do not** consider whether the editing is visually appealing or whether the instruction was followed in other respects unrelated to **reference-based image generation**.
* Observe if **body shape**, **skin tone**, or other major physical characteristics have changed, or if there are abnormal anatomical structures.
* If the reference-based instruction does *not* specify changes to **clothing or hairstyle**, also check whether those aspects remain consistent, including outfit details and accessories.
* Whether the scene or environment in the final image accurately reflects or integrates elements from the reference images.
* Check for correct background blending (location, lighting, objects, layout) and the presence of key environmental details from the scene image.
**Example:** If the instruction requests combining the man from image 1, the woman from image 2, and the scene background from image 3, the final image should clearly depict the *same* man, woman, and scene as in those source images.
**Important:**
* Every time there is a difference, deduct one point.
* Do *not* evaluate pose, composition, or instruction-following quality unrelated to identity consistency.
* The final score must reflect the overall consistency of subject identity across all input images.
* **Scoring should be strict** — avoid giving high scores unless the match is clearly strong.
Editing instruction: <instruction>
"""
class PromptGenerator:
def __init__(self):
pass
def __call__(self, input_instruction: str, task_type: str, with_scene=False) -> str:
prompt = _context_no_delimit
if task_type == "prompt_following":
if with_scene:
prompt += _prompts_0shot_in_context_generation_rule_PF_Scene
else:
prompt += _prompts_0shot_in_context_generation_rule_PF_Single_and_Multiple
elif task_type == "subject_consistency":
if with_scene:
prompt += _prompts_0shot_in_context_generation_rule_SC_Scene
else:
prompt += _prompts_0shot_in_context_generation_rule_SC_Single_and_Multiple
else:
raise ValueError(f"Invalid task type: {task_type}")
prompt = prompt.replace("<instruction>", input_instruction)
return prompt
import math
import datasets
import argparse
import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import glob
from collections import defaultdict
from PIL import Image
from omnicontext.omnicontext_score import OmniContextScore
def process_single_item(item, vie_score, max_retries=5):
instruction = item['instruction']
key = item['key']
instruction_language = item['instruction_language']
input_images = item['input_images']
output_image = Image.open(item['output_image_path']).convert("RGB")
ori_img_sizes = [input_image.size for input_image in input_images]
new_img_sizes = []
for ori_img_size in ori_img_sizes:
if ori_img_size[0] * ori_img_size[1] > 1024 * 1024:
ratio = math.sqrt(1024 * 1024 / (ori_img_size[0] * ori_img_size[1]))
new_img_size = (int(ori_img_size[0] * ratio), int(ori_img_size[1] * ratio))
else:
new_img_size = ori_img_size
new_img_size = (new_img_size[0] // 16 * 16, new_img_size[1] // 16 * 16)
new_img_sizes.append(new_img_size)
input_images = [input_image.resize(new_img_size) for input_image, new_img_size in zip(input_images, new_img_sizes)]
result_dict = {
'key': key,
'task_type': item['task_type'],
'instruction': instruction,
'instruction_language': instruction_language,
'output_image_path': item['output_image_path'],
}
with_scene = 'scene' in item['task_type']
score_dict = vie_score.evaluate(input_images + [output_image], instruction, with_scene=with_scene)
print(f"{score_dict=}", flush=True)
result_dict['PF_score'] = score_dict['PF_scores']['score']
result_dict['PF_score_reason'] = score_dict['PF_scores']['reasoning']
result_dict['SC_score'] = score_dict['SC_scores']['score']
result_dict['SC_score_reason'] = score_dict['SC_scores']['reasoning']
return result_dict
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--result_dir", type=str, required=True)
    parser.add_argument("--openai_url", type=str, default="https://api.openai.com/v1/chat/completions")
    parser.add_argument("--openai_key", type=str, required=True)
    parser.add_argument("--test_data", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--max_workers", type=int, default=100)
    args = parser.parse_args()

    omnicontext_score = OmniContextScore(args.openai_url, args.openai_key)

    # Group the benchmark examples by task type.
    test_dataset = datasets.load_dataset(args.test_data, split="train")
    sub_datasets = defaultdict(list)
    for example in test_dataset:
        task_type = example['task_type']
        sub_datasets[task_type].append(example)

    all_result_list = []
    for task_type, sub_data in sub_datasets.items():
        result_list = []
        json_path = os.path.join(args.result_dir, args.model_name, "gpt4dot1", task_type, "score.jsonl")

        # Reuse previously computed scores for this task type if they exist.
        if os.path.exists(json_path):
            with open(json_path, 'r', encoding='utf-8') as f:
                for line in f:
                    result = json.loads(line)
                    result_list.append(result)
            print(f"Loaded {json_path} for {task_type}, length: {len(result_list)}")
            # Keep cached scores in the combined output as well.
            all_result_list.extend(result_list)
            continue

        with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
            futures = []
            for item in sub_data:
                key = item["key"]
                output_image_path = os.path.join(args.result_dir, args.model_name, "fullset", task_type, f"{key}.png")
                item['output_image_path'] = output_image_path
                if not os.path.exists(output_image_path):
                    print(f"Output image not found: {output_image_path}, skip")
                    continue
                future = executor.submit(process_single_item, item, omnicontext_score)
                futures.append(future)
            for future in tqdm(as_completed(futures), total=len(futures), unit="image", desc=f"Processing {task_type}"):
                result = future.result()
                if result:
                    result_list.append(result)

        all_result_list.extend(result_list)

        # Save per-task-type scores as JSONL.
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
        with open(json_path, 'w', encoding='utf-8') as f:
            for result in result_list:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
        print(f"Saved {json_path} for {task_type}, length: {len(result_list)}")

    # Save all task types into a single combined JSONL.
    combined_json_path = os.path.join(args.result_dir, args.model_name, "gpt4dot1", "combined_score.jsonl")
    os.makedirs(os.path.dirname(combined_json_path), exist_ok=True)
    with open(combined_json_path, 'w', encoding='utf-8') as f:
        for result in all_result_list:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
<!-- <h1 align="center">OmniGen2</h1> -->
<p align="center">
<img src="assets/brand.png" width="65%">
</p>
<p align="center">
<a href="https://vectorspacelab.github.io/OmniGen2"><img src="https://img.shields.io/badge/Project%20Page-OmniGen2-yellow" alt="project page"></a>
<a href="https://arxiv.org/abs/2506.18871"><img src="https://img.shields.io/badge/arXiv%20paper-2506.18871-b31b1b.svg" alt="arxiv"></a>
<a href="https://github.com/VectorSpaceLab/OmniGen2?tab=readme-ov-file#-gradio-demo"><img src="https://img.shields.io/badge/Online%20Demo-🤗-blue" alt="demo"></a>
<a href="https://huggingface.co/spaces/OmniGen2/OmniGen2"><img src="https://img.shields.io/badge/HF%20Spaces-🤗-lightblue" alt="demo"></a>
<a href="https://huggingface.co/OmniGen2/OmniGen2"><img src="https://img.shields.io/badge/Model-🤗-yellow" alt="model"></a>
<a href="https://huggingface.co/datasets/OmniGen2/OmniContext"><img src="https://img.shields.io/badge/Benchmark-🤗-yellow" alt="model"></a>
<a href="https://huggingface.co/datasets/OmniGen2/X2I2"><img src="https://img.shields.io/badge/Dataset-🤗-yellow" alt="model"></a>
</p>
<h4 align="center">
<p>
<a href=#-news>News</a> |
<a href=#-quick-start>Quick Start</a> |
<a href=#-usage-tips>Usage Tips</a> |
<a href=#-gradio-demo>Online Demos</a> |
<a href="#heart-citing-us">Citation</a> |
<a href="#license">License</a>
</p>
</h4>
**OmniGen2** is a powerful and efficient unified multimodal model. Its architecture is composed of two key components: a 3B Vision-Language Model (VLM) and a 4B diffusion model. In this design, the frozen 3B VLM ([Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)) is responsible for interpreting both visual signals and user instructions, while the 4B diffusion model leverages this understanding to perform high-quality image generation.
This dual-component architecture enables strong performance across four primary capabilities:
- **Visual Understanding**: Inherits the robust ability to interpret and analyze image content from its Qwen2.5-VL foundation.
- **Text-to-Image Generation**: Creates high-fidelity and aesthetically pleasing images from textual prompts.
- **Instruction-guided Image Editing**: Executes complex, instruction-based image modifications with high precision, achieving state-of-the-art performance among open-source models.
- **In-context Generation**: A versatile capability to process and flexibly combine diverse inputs—including humans, reference objects, and scenes—to produce novel and coherent visual outputs.
As an open-source project, OmniGen2 provides a powerful yet resource-efficient foundation for researchers and developers exploring the frontiers of controllable and personalized generative AI.
**We will release the training code, dataset, and data construction pipeline soon. Stay tuned!**
<p align="center">
<img src="assets/teaser.png" width="95%">
<br>
<em>Demonstration of OmniGen2's overall capabilities.</em>
</p>
<p align="center">
<img src="assets/examples_edit.png" width="95%">
<br>
<em>Demonstration of OmniGen2's image editing capabilities.</em>
</p>
<p align="center">
<img src="assets/examples_subject.png" width="95%">
<br>
<em>Demonstration of OmniGen2's in-context generation capabilities.</em>
</p>
## 🔥 News
- **2025-07-05**: Training datasets [X2I2](https://huggingface.co/datasets/OmniGen2/X2I2) are available.
- **2025-07-03**: OmniGen2 now supports [TeaCache](https://github.com/ali-vilab/TeaCache) and [TaylorSeer](https://github.com/Shenyi-Z/TaylorSeer) for faster inference; see [Usage Tips](#-usage-tips) for details. Thanks to @legitnull for the great [TeaCache PR](https://github.com/VectorSpaceLab/OmniGen2/pull/52) and [TaylorSeer PR](https://github.com/VectorSpaceLab/OmniGen2/pull/76).
- **2025-07-01**: OmniGen2 is supported by the official [ComfyUI examples](https://comfyanonymous.github.io/ComfyUI_examples/omnigen), thanks!
- **2025-06-30**: Training code is available, see [fine-tuning](docs/FINETUNE.md) for details.
- **2025-06-28**: We release [OmniContext](https://huggingface.co/datasets/OmniGen2/OmniContext) benchmark. The evaluation codes are in [omnicontext](https://github.com/VectorSpaceLab/OmniGen2/tree/main/omnicontext).
- **2025-06-24**: [Technical Report](https://arxiv.org/abs/2506.18871) is available.
- **2025-06-23**: We’ve updated our code and HF model—OmniGen2 now runs *without* `flash-attn`. Users can still install it for optimal performance.
- **2025-06-20**: Updated [resource requirements](#-resources-requirement), adding CPU offload support for devices with limited VRAM.
- **2025-06-16**: The [Gradio](https://github.com/VectorSpaceLab/OmniGen2?tab=readme-ov-file#-gradio-demo) and [Jupyter](https://github.com/VectorSpaceLab/OmniGen2/blob/main/example.ipynb) demos are available. Online Gradio demos: [Demo1](https://9c4426d27c3b9ecbed.gradio.live); [Chat-Demo1](https://0351497834a4d7226c.gradio.live); see more demo links in the [gradio section](https://github.com/VectorSpaceLab/OmniGen2?tab=readme-ov-file#-gradio-demo).
- **2025-06-16**: We release **OmniGen2**, a multimodal generation model; model weights are available on [Hugging Face](https://huggingface.co/OmniGen2/OmniGen2) and [ModelScope](https://www.modelscope.cn/models/OmniGen2/OmniGen2).
## 📌 TODO
- [x] Technical report.
- [x] Support CPU offload and improve inference efficiency.
- [x] In-context generation benchmark: **OmniContext**.
- [ ] Integration with diffusers.
- [x] Training datasets.
- [ ] Training data construction pipeline.
- [ ] ComfyUI Demo (**community support would be greatly appreciated!**).
## 🚀 Quick Start
### 🛠️ Environment Setup
#### ✅ Recommended Setup
```bash
# 1. Clone the repo
git clone git@github.com:VectorSpaceLab/OmniGen2.git
cd OmniGen2
# 2. (Optional) Create a clean Python environment
conda create -n omnigen2 python=3.11
conda activate omnigen2
# 3. Install dependencies
# 3.1 Install PyTorch (choose correct CUDA version)
pip install torch==2.6.0 torchvision --extra-index-url https://download.pytorch.org/whl/cu124
# 3.2 Install other required packages
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
```
#### 🌏 For users in Mainland China
```bash
# Install PyTorch from a domestic mirror
pip install torch==2.6.0 torchvision --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu124
# Install other dependencies from Tsinghua mirror
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install flash-attn --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple
```
---
### 🧪 Run Examples
```bash
# Visual Understanding
bash example_understanding.sh
# Text-to-image generation
bash example_t2i.sh
# Instruction-guided image editing
bash example_edit.sh
# Subject-driven image editing
bash example_subject_driven_edit.sh
```
---
### 🌐 Gradio Demo
* **Online Demo**:
We are temporarily providing 8 GPUs to support the online demos. If you notice a long queue for a particular link, please try other links:
[Demo1](https://be5916033313307354.gradio.live), [Demo2](https://281efc44b736406f42.gradio.live), [Demo3](https://a27912fbaef54294f8.gradio.live), [Demo4](https://bbf305e391bc769d22.gradio.live)
[Chat-Demo1](https://a79e0445bb498554e8.gradio.live), [Chat-Demo2](https://7f922fdca66e47c427.gradio.live), [Chat-Demo3](https://6568f4b2a8353be3ae.gradio.live), [Chat-Demo4](https://f38c30ed99f0f6caab.gradio.live)
<!-- [Available on Hugging Face Spaces 🚀](https://huggingface.co/spaces/Shitao/OmniGen2) -->
* **Run Locally**:
```bash
pip install gradio
python app.py
# Optional: share the demo with a public link (requires access to huggingface.co)
python app.py --share
```
## 💡 Usage Tips
To achieve optimal results with OmniGen2, you can adjust the following key hyperparameters based on your specific use case; a minimal programmatic sketch using these settings follows the list.
- `num_inference_step`: The number of sampling steps per generation. Higher values generally improve quality but increase generation time.
  - Recommended range: 28 to 50.
- `text_guidance_scale`: Controls how strictly the output adheres to the text prompt (Classifier-Free Guidance).
  - **For Text-to-Image**: Use a higher value (e.g., 6-7) for simple or less detailed prompts. Use a lower value (e.g., 4) for complex and highly detailed prompts.
  - **For Editing/Composition**: A moderate value around 4-5 is recommended.
- `image_guidance_scale`: Controls how closely the final image should resemble the input reference image.
  - **The Trade-off**: A higher value (~2.0) makes the output more faithful to the reference image's structure and style, but it might ignore parts of your text prompt. A lower value (~1.5) gives the text prompt more influence.
  - **Tip**: Start with 1.5 and increase it if you need more consistency with the reference image. For image editing tasks, we recommend setting it between 1.3 and 2.0; for in-context generation tasks, a higher `image_guidance_scale` preserves more detail from the input images, and we recommend setting it between 2.5 and 3.0.
- `max_pixels`: Automatically resizes an image when its total pixel count (width × height) exceeds this limit, while maintaining its aspect ratio. This helps manage performance and memory usage.
- `max_input_image_side_length`: Maximum side length for input images.
- `negative_prompt`: Tells the model what you don't want to see in the image.
  - **Example**: blurry, low quality, text, watermark
  - **Tip**: For the best results, try experimenting with different negative prompts. If you're not sure, just leave it blank.
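
If you prefer calling the pipeline from Python instead of the shell scripts, the snippet below is a minimal, illustrative sketch of how these settings map onto pipeline arguments. The loading call, the keyword argument names, and the `.images` output attribute are assumptions that mirror the CLI flags and common diffusers conventions; check them against the example scripts in this repo before relying on them.

```python
# Illustrative sketch only: argument names mirror the CLI flags above and may
# need to be adapted to the actual OmniGen2Pipeline signature.
import torch
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

pipe = OmniGen2Pipeline.from_pretrained("OmniGen2/OmniGen2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

result = pipe(
    prompt="A dog running in the park",
    negative_prompt="blurry, low quality, text, watermark",
    width=1024,
    height=1024,
    num_inference_steps=50,        # recommended range: 28-50
    text_guidance_scale=6.0,       # higher for simple prompts, lower for detailed ones
    image_guidance_scale=1.5,      # raise toward 2.0-3.0 when reference fidelity matters
    generator=torch.Generator("cuda").manual_seed(0),
)
result.images[0].save("output.png")  # assumes the standard diffusers-style `.images` output
```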
<!-- ## 💻 Resources Requirement
OmniGen2 require around 21G GPU memory for BF16 inference. For users do not have such GPU memory, may try: -->
## ❤️ Citing Us
If you find this repository or our work useful, please consider giving it a star ⭐ and a citation 🦖, which would be greatly appreciated:
```bibtex
@article{xiao2024omnigen,
title={Omnigen: Unified image generation},
author={Xiao, Shitao and Wang, Yueze and Zhou, Junjie and Yuan, Huaying and Xing, Xingrun and Yan, Ruiran and Wang, Shuting and Huang, Tiejun and Liu, Zheng},
journal={arXiv preprint arXiv:2409.11340},
year={2024}
}
```
## License
This work is licensed under the Apache 2.0 License.
from .cache_init import cache_init
from .cal_type import cal_type
from .force_scheduler import force_scheduler
# Modified from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cache_init.py
# Type hinting would cause a circular import; `self` should be an `OmniGen2Pipeline`.
def cache_init(self, num_steps: int):
    '''
    Initialization for cache.
    '''
    cache_dic = {}
    cache = {}
    cache_index = {}
    cache[-1] = {}
    cache_index[-1] = {}
    cache_index['layer_index'] = {}
    cache[-1]['layers_stream'] = {}
    cache_dic['cache_counter'] = 0

    # One cache slot per transformer layer.
    for j in range(len(self.transformer.layers)):
        cache[-1]['layers_stream'][j] = {}
        cache_index[-1][j] = {}

    cache_dic['Delta-DiT'] = False
    cache_dic['cache_type'] = 'random'
    cache_dic['cache_index'] = cache_index
    cache_dic['cache'] = cache
    cache_dic['fresh_ratio_schedule'] = 'ToCa'
    cache_dic['fresh_ratio'] = 0.0
    cache_dic['fresh_threshold'] = 3
    cache_dic['soft_fresh_weight'] = 0.0
    cache_dic['taylor_cache'] = True
    cache_dic['max_order'] = 4
    cache_dic['first_enhance'] = 5

    # Per-run bookkeeping for the sampling loop.
    current = {}
    current['activated_steps'] = [0]
    current['step'] = 0
    current['num_steps'] = num_steps

    return cache_dic, current
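
# Minimal sketch of exercising cache_init in isolation; `dummy_pipe` below is an
# illustrative stand-in for a pipeline object, not part of the real codebase:
#
#   from types import SimpleNamespace
#   dummy_pipe = SimpleNamespace(transformer=SimpleNamespace(layers=[None] * 4))
#   cache_dic, current = cache_init(dummy_pipe, num_steps=50)
#   assert set(cache_dic['cache'][-1]['layers_stream']) == {0, 1, 2, 3}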