Commit 109f0842 authored by chenzk

v1.0
import dotenv
dotenv.load_dotenv(override=True)
import argparse
import os
from typing import List, Optional
from PIL import Image, ImageOps
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from accelerate import Accelerator
from diffusers.hooks import apply_group_offloading
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="OmniGen2 image generation script.")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--transformer_path",
type=str,
default=None,
help="Path to transformer checkpoint.",
)
parser.add_argument(
"--transformer_lora_path",
type=str,
default=None,
help="Path to transformer LoRA checkpoint.",
)
parser.add_argument(
"--scheduler",
type=str,
default="euler",
choices=["euler", "dpmsolver++"],
help="Scheduler to use.",
)
parser.add_argument(
"--num_inference_step",
type=int,
default=50,
help="Number of inference steps."
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for generation."
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Output image height."
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Output image width."
)
parser.add_argument(
"--max_input_image_pixels",
type=int,
default=1048576,
help="Maximum number of pixels for each input image."
)
parser.add_argument(
"--dtype",
type=str,
default='bf16',
choices=['fp32', 'fp16', 'bf16'],
help="Data type for model weights."
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=5.0,
help="Text guidance scale."
)
parser.add_argument(
"--image_guidance_scale",
type=float,
default=2.0,
help="Image guidance scale."
)
parser.add_argument(
"--cfg_range_start",
type=float,
default=0.0,
help="Start of the CFG range."
)
parser.add_argument(
"--cfg_range_end",
type=float,
default=1.0,
help="End of the CFG range."
)
parser.add_argument(
"--instruction",
type=str,
default="A dog running in the park",
help="Text prompt for generation."
)
parser.add_argument(
"--negative_prompt",
type=str,
default="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
help="Negative prompt for generation."
)
parser.add_argument(
"--input_image_path",
type=str,
nargs='+',
default=None,
help="Path(s) to input image(s)."
)
parser.add_argument(
"--output_image_path",
type=str,
default="output.png",
help="Path to save output image."
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images to generate per prompt."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--enable_group_offload",
action="store_true",
help="Enable group offload."
)
parser.add_argument(
"--enable_teacache",
action="store_true",
help="Enable teacache to speed up inference."
)
parser.add_argument(
"--teacache_rel_l1_thresh",
type=float,
default=0.05,
help="Relative L1 threshold for teacache."
)
parser.add_argument(
"--enable_taylorseer",
action="store_true",
help="Enable TaylorSeer Caching."
)
return parser.parse_args()
def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
pipeline = OmniGen2Pipeline.from_pretrained(
args.model_path,
torch_dtype=weight_dtype,
trust_remote_code=True,
)
if args.transformer_path:
print(f"Transformer weights loaded from {args.transformer_path}")
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.transformer_path,
torch_dtype=weight_dtype,
)
else:
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
if args.transformer_lora_path:
print(f"LoRA weights loaded from {args.transformer_lora_path}")
pipeline.load_lora_weights(args.transformer_lora_path)
if args.enable_teacache and args.enable_taylorseer:
print("WARNING: enable_teacache and enable_taylorseer are mutually exclusive. enable_teacache will be ignored.")
if args.enable_taylorseer:
pipeline.enable_taylorseer = True
elif args.enable_teacache:
pipeline.transformer.enable_teacache = True
pipeline.transformer.teacache_rel_l1_thresh = args.teacache_rel_l1_thresh
if args.scheduler == "dpmsolver++":
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
pipeline.scheduler = scheduler
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
elif args.enable_group_offload:
apply_group_offloading(pipeline.transformer, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.mllm, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
def preprocess(input_image_path: Optional[List[str]] = None) -> Optional[List[Image.Image]]:
"""Preprocess the input images."""
# Process input images
input_images = None
if input_image_path:
input_images = []
if isinstance(input_image_path, str):
input_image_path = [input_image_path]
if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
input_images = [Image.open(os.path.join(input_image_path[0], f)).convert("RGB")
for f in os.listdir(input_image_path[0])]
else:
input_images = [Image.open(path).convert("RGB") for path in input_image_path]
input_images = [ImageOps.exif_transpose(img) for img in input_images]
return input_images
def run(args: argparse.Namespace,
accelerator: Accelerator,
pipeline: OmniGen2Pipeline,
instruction: str,
negative_prompt: str,
input_images: List[Image.Image]) -> Image.Image:
"""Run the image generation pipeline with the given parameters."""
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
results = pipeline(
prompt=instruction,
input_images=input_images,
width=args.width,
height=args.height,
num_inference_steps=args.num_inference_step,
max_sequence_length=1024,
text_guidance_scale=args.text_guidance_scale,
image_guidance_scale=args.image_guidance_scale,
cfg_range=(args.cfg_range_start, args.cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
)
return results
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)
def main(args: argparse.Namespace, root_dir: str) -> None:
"""Main function to run the image generation process."""
# Initialize accelerator
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != 'fp32' else 'no')
# Set weight dtype
weight_dtype = torch.float32
if args.dtype == 'fp16':
weight_dtype = torch.float16
elif args.dtype == 'bf16':
weight_dtype = torch.bfloat16
# Load pipeline and process inputs
pipeline = load_pipeline(args, accelerator, weight_dtype)
input_images = preprocess(args.input_image_path)
# Generate and save image
results = run(args, accelerator, pipeline, args.instruction, args.negative_prompt, input_images)
    # Guard against a bare output filename whose dirname is "" (makedirs would fail on it).
    output_dir = os.path.dirname(args.output_image_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
if len(results.images) > 1:
for i, image in enumerate(results.images):
image_name, ext = os.path.splitext(args.output_image_path)
image.save(f"{image_name}_{i}{ext}")
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
output_image.save(args.output_image_path)
print(f"Image saved to {args.output_image_path}")
if __name__ == "__main__":
root_dir = os.path.abspath(os.path.join(__file__, os.path.pardir))
args = parse_args()
main(args, root_dir)
import dotenv
dotenv.load_dotenv(override=True)
import argparse
import os
from typing import List, Optional
from PIL import Image
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from accelerate import Accelerator
from diffusers.hooks import apply_group_offloading
from omnigen2.pipelines.omnigen2.pipeline_omnigen2_chat import OmniGen2ChatPipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="OmniGen2 image generation script.")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--scheduler",
type=str,
default="euler",
choices=["euler", "dpmsolver++"],
help="Scheduler to use.",
)
parser.add_argument(
"--num_inference_step",
type=int,
default=50,
help="Number of inference steps."
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for generation."
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Output image height."
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Output image width."
)
parser.add_argument(
"--max_input_image_pixels",
type=int,
default=1048576,
help="Maximum number of pixels for each input image."
)
parser.add_argument(
"--dtype",
type=str,
default='bf16',
choices=['fp32', 'fp16', 'bf16'],
help="Data type for model weights."
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=5.0,
help="Text guidance scale."
)
parser.add_argument(
"--image_guidance_scale",
type=float,
default=2.0,
help="Image guidance scale."
)
parser.add_argument(
"--cfg_range_start",
type=float,
default=0.0,
help="Start of the CFG range."
)
parser.add_argument(
"--cfg_range_end",
type=float,
default=1.0,
help="End of the CFG range."
)
parser.add_argument(
"--instruction",
type=str,
default="A dog running in the park",
help="Text prompt for generation."
)
parser.add_argument(
"--negative_prompt",
type=str,
default="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
help="Negative prompt for generation."
)
parser.add_argument(
"--input_image_path",
type=str,
nargs='+',
default=None,
help="Path(s) to input image(s)."
)
parser.add_argument(
"--output_image_path",
type=str,
default="output.png",
help="Path to save output image."
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images to generate per prompt."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_group_offload",
action="store_true",
help="Enable group offload."
)
return parser.parse_args()
def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2ChatPipeline:
pipeline = OmniGen2ChatPipeline.from_pretrained(
args.model_path,
torch_dtype=weight_dtype,
trust_remote_code=True,
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
if args.scheduler == "dpmsolver":
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
pipeline.scheduler = scheduler
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
elif args.enable_group_offload:
apply_group_offloading(pipeline.transformer, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.mllm, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
def preprocess(input_image_path: Optional[List[str]] = None) -> Optional[List[Image.Image]]:
"""Preprocess the input images."""
# Process input images
input_images = None
if input_image_path:
input_images = []
if isinstance(input_image_path, str):
input_image_path = [input_image_path]
if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
input_images = [Image.open(os.path.join(input_image_path[0], f)).convert("RGB")
for f in os.listdir(input_image_path[0])]
else:
input_images = [Image.open(path).convert("RGB") for path in input_image_path]
return input_images
def run(args: argparse.Namespace,
accelerator: Accelerator,
pipeline: OmniGen2ChatPipeline,
instruction: str,
negative_prompt: str,
input_images: List[Image.Image]) -> Image.Image:
"""Run the image generation pipeline with the given parameters."""
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
results = pipeline(
prompt=instruction,
input_images=input_images,
width=args.width,
height=args.height,
num_inference_steps=args.num_inference_step,
max_sequence_length=1024,
text_guidance_scale=args.text_guidance_scale,
image_guidance_scale=args.image_guidance_scale,
cfg_range=(args.cfg_range_start, args.cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
)
return results
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)
def main(args: argparse.Namespace, root_dir: str) -> None:
"""Main function to run the image generation process."""
# Initialize accelerator
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != 'fp32' else 'no')
# Set weight dtype
weight_dtype = torch.float32
if args.dtype == 'fp16':
weight_dtype = torch.float16
elif args.dtype == 'bf16':
weight_dtype = torch.bfloat16
# Load pipeline and process inputs
pipeline = load_pipeline(args, accelerator, weight_dtype)
input_images = preprocess(args.input_image_path)
# Generate and save image
results = run(args, accelerator, pipeline, args.instruction, args.negative_prompt, input_images)
if results.images is not None:
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
        output_dir = os.path.dirname(args.output_image_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
output_image.save(args.output_image_path)
print(f"Image saved to {args.output_image_path}")
print(f"Text: {results.text}")
if __name__ == "__main__":
root_dir = os.path.abspath(os.path.join(__file__, os.path.pardir))
args = parse_args()
main(args, root_dir)
# Model code
modelCode=1670
# Model name
modelName=OmniGen2_pytorch
# Model description
modelDescription=Introduces a reflection mechanism, dominates multimodal generation leaderboards, and unlocks Doraemon's "Anywhere Door" for AI image generation in one click.
# Application scenarios
appScenario=Inference, multimodal, manufacturing, broadcast media, finance, energy, healthcare, smart home, education
# Framework type
frameType=pytorch
# OmniContext
As part of OmniGen2, we introduce **OmniContext**, a new benchmark that aims to evaluate models' in-context generation abilities more comprehensively. It incorporates a diverse set of input images and instructions, and uses GPT-4.1 for interpretable, metric-driven assessment.
<p align="center">
<img src="../assets/omnicontext_overview.png" width="95%">
<br>
<em>Overview of OmniContext benchmark.</em>
</p>
<p align="center">
<img src="../assets/omnicontext_evaluation.png" width="95%">
<br>
<em>An illustrative evaluation case in the OmniContext benchmark.</em>
</p>
The evaluation of the OmniContext benchmark includes the following steps:
## Step 1: Environment Setup
```bash
# 1. Activate Python environment
conda activate omnigen2
# 2. Install dependencies
pip install -U datasets megfile
```
## Step 2: Generate Images
Note: we fix the output image resolution at 1024 × 1024 so that settings are consistent across models.
You can generate results with OmniGen2 or any other model; just make sure the output directory structure and file format match the layout specified below.
```
results/
├── {method_name}/
│ └── fullset/
│ └── {task_type}/
│ ├── key1.png
│ ├── key2.png
│ └── ...
```
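Before scoring, a quick sanity check along the following lines can confirm that generated images follow this layout (a minimal sketch; `result_dir` and `method_name` are placeholders for your own paths):

```python
import glob
import os

result_dir = "results"     # placeholder: your results root
method_name = "OmniGen2"   # placeholder: your {method_name} directory

pattern = os.path.join(result_dir, method_name, "fullset", "*", "*.png")
paths = sorted(glob.glob(pattern))
print(f"{len(paths)} images found")
for path in paths[:5]:
    task_type = os.path.basename(os.path.dirname(path))  # the {task_type} folder
    key = os.path.splitext(os.path.basename(path))[0]    # the sample key
    print(task_type, key)
```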
To use OmniGen2, you can run the following script to generate images:
```bash
cd OmniGen2
accelerate launch --num_processes=8 -m omnicontext.inference \
--model_path "OmniGen2/OmniGen2" \
--model_name "OmniGen2" \
--test_data "OmniGen2/OmniContext" \
--result_dir "omnicontext/results" \
--num_inference_step 50 \
--height 1024 \
--width 1024 \
--text_guidance_scale 5.0 \
--image_guidance_scale 2.0 \
--num_images_per_prompt 1 \
    --disable_align_res # align_res matches the output resolution to the input image for editing tasks; disable it for in-context generation tasks.
```
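For a single-GPU run, the same command works with `--num_processes=1`. With the arguments above, images are written to `omnicontext/results/OmniGen2/fullset/{task_type}/{key}.png`, matching the layout expected by the evaluation step.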
## Step 3: Evaluation
1. We use GPT-4.1 to evaluate the quality of the generated images. Please make sure to set up your API key before running the script.
```bash
cd OmniGen2
openai_key="<Your-API-Key>"
python -m omnicontext.test_omnicontext_score \
--test_data "OmniGen2/OmniContext" \
--result_dir "omnicontext/results" \
--model_name "OmniGen2" \
--openai_key ${openai_key} \
--max_workers 100
```
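The statistics script below expects the resulting per-task JSONL files under `omnicontext/results/OmniGen2/gpt4dot1/`, with each line recording one sample and carrying at least the `instruction_language`, `PF_score`, and `SC_score` fields.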
2. Next, calculate the final score:
```bash
python -m omnicontext.calculate_statistics \
--save_path "omnicontext/results" \
--model_name "OmniGen2" \
--backbone gpt4dot1
```
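As implemented in the `analyze_scores` helper shipped with this commit, each sample's overall score is the geometric mean of its prompt-following (PF) and subject-consistency (SC) scores, averaged per task type and then across task types. A minimal illustration of the per-sample formula:

```python
import math

def overall_score(pf_score: float, sc_score: float) -> float:
    """Geometric mean of PF and SC, as computed in analyze_scores."""
    return math.sqrt(pf_score * sc_score)

# A sample scoring 7 on prompt following and 8 on subject consistency:
print(round(overall_score(7, 8), 3))  # 7.483 -- a zero on either axis zeroes the total
```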
## Acknowledgements
The code structure of this benchmark is inspired by [Step1X-Edit](https://github.com/stepfun-ai/Step1X-Edit).
Special thanks to the original project for their valuable contribution.
import glob
import json
import math
import os
from collections import defaultdict

import numpy as np
def analyze_scores(json_lines, language):
group_prompt_following_scores = {}
group_subject_consistency_scores = {}
group_overall_scores = {}
for task_type in json_lines.keys():
prompt_following_scores = []
subject_consistency_scores = []
overall_scores = []
for json_line in json_lines[task_type]:
if json_line['instruction_language'] != language:
continue
prompt_following_score = json_line['PF_score']
subject_consistency_score = json_line['SC_score']
overall_score = math.sqrt(prompt_following_score * subject_consistency_score)
prompt_following_scores.append(prompt_following_score)
subject_consistency_scores.append(subject_consistency_score)
overall_scores.append(overall_score)
group_prompt_following_scores[task_type] = np.mean(prompt_following_scores)
group_subject_consistency_scores[task_type] = np.mean(subject_consistency_scores)
group_overall_scores[task_type] = np.mean(overall_scores)
return group_prompt_following_scores, group_subject_consistency_scores, group_overall_scores
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--save_path", type=str, default="/results/")
parser.add_argument("--backbone", type=str, default="gpt4dot1")
parser.add_argument("--model_name", type=str, default="OmniGen2")
parser.add_argument("--language", type=str, default="en")
args = parser.parse_args()
result_json_files = glob.glob(os.path.join(args.save_path, args.model_name, args.backbone, "**/*.jsonl"))
print(f"{result_json_files=}")
print(f"{len(result_json_files)=}")
result_json_lines = defaultdict(list)
for result_json_file in result_json_files:
with open(result_json_file, 'r') as f:
for line in f:
data = json.loads(line)
task_type = os.path.basename(os.path.dirname(result_json_file))
result_json_lines[task_type].append(data)
group_prompt_following_scores, group_subject_consistency_scores, group_overall_scores = analyze_scores(result_json_lines, language=args.language)
for task_type in group_prompt_following_scores.keys():
print(f"{task_type}: {group_prompt_following_scores[task_type]:.3f}, {group_subject_consistency_scores[task_type]:.3f}, {group_overall_scores[task_type]:.3f}")
print(f"Average: {np.mean(list(group_prompt_following_scores.values())):.3f}, {np.mean(list(group_subject_consistency_scores.values())):.3f}, {np.mean(list(group_overall_scores.values())):.3f}")
import dotenv
dotenv.load_dotenv(override=True)
import argparse
import os
import datasets
from tqdm import tqdm
from typing import List, Optional
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor
from accelerate import Accelerator
from diffusers.hooks import apply_group_offloading
from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="OmniGen2 image generation script.")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to model checkpoint.",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Model name for output directory.",
)
parser.add_argument(
"--scheduler",
type=str,
default="euler",
choices=["euler", "dpmsolver"],
help="Scheduler to use.",
)
parser.add_argument(
"--num_inference_step",
type=int,
default=50,
help="Number of inference steps."
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for generation."
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Output image height."
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Output image width."
)
parser.add_argument(
"--max_input_image_pixels",
type=int,
default=1048576,
help="Maximum number of pixels for each input image."
)
parser.add_argument(
"--dtype",
type=str,
default='bf16',
choices=['fp32', 'fp16', 'bf16'],
help="Data type for model weights."
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=5.0,
help="Text guidance scale."
)
parser.add_argument(
"--image_guidance_scale",
type=float,
default=2.0,
help="Image guidance scale."
)
parser.add_argument(
"--cfg_range_start",
type=float,
default=0.0,
help="Start of the CFG range."
)
parser.add_argument(
"--cfg_range_end",
type=float,
default=1.0,
help="End of the CFG range."
)
parser.add_argument(
"--negative_prompt",
type=str,
default="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
help="Negative prompt for generation."
)
parser.add_argument(
"--test_data",
type=str,
default=None,
help="Path to test data."
)
parser.add_argument(
"--result_dir",
type=str,
default="results",
help="Path to save the generated images."
)
parser.add_argument(
"--num_images_per_prompt",
type=int,
default=1,
help="Number of images to generate per prompt."
)
parser.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Enable model CPU offload."
)
parser.add_argument(
"--enable_sequential_cpu_offload",
action="store_true",
help="Enable sequential CPU offload."
)
parser.add_argument(
"--enable_group_offload",
action="store_true",
help="Enable group offload."
)
parser.add_argument(
"--disable_align_res",
action="store_true",
help="Align resolution to the input image resolution."
)
return parser.parse_args()
class Collator:
def __call__(self, features):
return features
def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
from transformers import CLIPProcessor
pipeline = OmniGen2Pipeline.from_pretrained(
args.model_path,
processor=CLIPProcessor.from_pretrained(
args.model_path,
subfolder="processor",
use_fast=True
),
torch_dtype=weight_dtype,
trust_remote_code=True,
)
pipeline.transformer = OmniGen2Transformer2DModel.from_pretrained(
args.model_path,
subfolder="transformer",
torch_dtype=weight_dtype,
)
if args.scheduler == "dpmsolver":
from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
scheduler = DPMSolverMultistepScheduler(
algorithm_type="dpmsolver++",
solver_type="midpoint",
solver_order=2,
prediction_type="flow_prediction",
)
pipeline.scheduler = scheduler
if args.enable_sequential_cpu_offload:
pipeline.enable_sequential_cpu_offload()
elif args.enable_model_cpu_offload:
pipeline.enable_model_cpu_offload()
elif args.enable_group_offload:
apply_group_offloading(pipeline.transformer, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.mllm, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)
return pipeline
def preprocess(input_image_path: Optional[List[str]] = None) -> Optional[List[Image.Image]]:
"""Preprocess the input images."""
# Process input images
input_images = None
if input_image_path:
input_images = []
if isinstance(input_image_path, str):
input_image_path = [input_image_path]
if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
input_images = [Image.open(os.path.join(input_image_path[0], f)).convert("RGB")
for f in os.listdir(input_image_path[0])]
else:
input_images = [Image.open(path).convert("RGB") for path in input_image_path]
input_images = [ImageOps.exif_transpose(img) for img in input_images]
return input_images
def run(args: argparse.Namespace,
accelerator: Accelerator,
pipeline: OmniGen2Pipeline,
instruction: str,
negative_prompt: str,
input_images: List[Image.Image]) -> Image.Image:
"""Run the image generation pipeline with the given parameters."""
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
results = pipeline(
prompt=instruction,
input_images=input_images,
width=args.width,
height=args.height,
align_res=not args.disable_align_res,
num_inference_steps=args.num_inference_step,
max_sequence_length=1024,
text_guidance_scale=args.text_guidance_scale,
image_guidance_scale=args.image_guidance_scale,
cfg_range=(args.cfg_range_start, args.cfg_range_end),
negative_prompt=negative_prompt,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
)
return results
def create_collage(images: List[torch.Tensor]) -> Image.Image:
"""Create a horizontal collage from a list of images."""
max_height = max(img.shape[-2] for img in images)
total_width = sum(img.shape[-1] for img in images)
canvas = torch.zeros((3, max_height, total_width), device=images[0].device)
current_x = 0
for img in images:
h, w = img.shape[-2:]
canvas[:, :h, current_x:current_x+w] = img * 0.5 + 0.5
current_x += w
return to_pil_image(canvas)
def main(args: argparse.Namespace, root_dir: str) -> None:
"""Main function to run the image generation process."""
# Initialize accelerator
accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != 'fp32' else 'no')
test_dataset = datasets.load_dataset(args.test_data, split="train")
print('test_dataset', test_dataset)
loader = DataLoader(
test_dataset,
collate_fn=Collator(),
batch_size=1,
shuffle=True,
pin_memory=False,
drop_last=False,
)
loader = accelerator.prepare(loader)
# Set weight dtype
weight_dtype = torch.float32
if args.dtype == 'fp16':
weight_dtype = torch.float16
elif args.dtype == 'bf16':
weight_dtype = torch.bfloat16
# Load pipeline and process inputs
pipeline = load_pipeline(args, accelerator, weight_dtype)
with tqdm(
total=len(loader),
desc="Generating images...",
unit="image",
disable=not accelerator.is_main_process,
) as pbar:
        for batched_data in loader:  # the surrounding pbar already reports progress
            for data in batched_data:
                key = data['key']
                task_type = data['task_type']
                instruction = data['instruction']
                input_images = data['input_images']
                input_images = [ImageOps.exif_transpose(img) for img in input_images]
                sub_dir = os.path.join(args.result_dir, args.model_name, "fullset", task_type)
                os.makedirs(sub_dir, exist_ok=True)
                output_image_path = os.path.join(sub_dir, f"{key}.png")
                # Skip samples that already have an output *before* running the pipeline.
                if os.path.exists(output_image_path):
                    continue
                # Generate and save image
                results = run(args, accelerator, pipeline, instruction, args.negative_prompt, input_images)
if len(results.images) > 1:
for i, image in enumerate(results.images):
image_name, ext = os.path.splitext(output_image_path)
image.save(f"{image_name}_{i}{ext}")
vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
output_image = create_collage(vis_images)
output_image.save(output_image_path)
pbar.update(1)
if __name__ == "__main__":
root_dir = os.path.abspath(os.path.join(__file__, os.path.pardir))
args = parse_args()
main(args, root_dir)
import os
from typing import Union
import json
import regex as re
import ast
import random
def fix_json(input_str):
# Add double quotes around keys using regex
fixed_str = re.sub(r'(\w+):', r'"\1":', input_str)
# Add double quotes around string values if necessary and wrap int/float values in []
def format_value(match):
key, value, comma = match.groups()
value = value.strip()
# Check if value is an integer or float
if re.match(r'^-?\d+(\.\d+)?$', value):
value = f'[{value}]'
# Check if value is a boolean or null
elif re.match(r'^(true|false|null)$', value, re.IGNORECASE):
pass # leave as is
else:
# Add quotes around string values
value = f'"{value}"'
return f'{key}: {value}{comma}'
fixed_str = re.sub(r'(".*?"):(.*?)(,|})', format_value, fixed_str)
return fixed_str
def read_file_to_string(file_path):
"""
Reads the contents of a text file and returns it as a string.
:param file_path: The path to the text file.
:return: A string containing the contents of the file.
"""
try:
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
except FileNotFoundError:
print(f"The file {file_path} was not found.")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
def read_files_to_string(file_paths):
"""
Reads the contents of multiple text files and returns them as a single string,
with each file's contents separated by a newline.
:param file_paths: A list of paths to text files.
:return: A string containing the concatenated contents of the files.
"""
all_contents = [] # List to hold the contents of each file
for file_path in file_paths:
try:
with open(file_path, 'r', encoding='utf-8') as file:
all_contents.append(file.read())
except FileNotFoundError:
print(f"The file {file_path} was not found.")
except Exception as e:
print(f"An error occurred while reading {file_path}: {e}")
# Join all the contents with a newline character
return "\n".join(all_contents)
def get_file_path(filename: Union[str, os.PathLike], search_from: Union[str, os.PathLike] = "."):
"""
Search for a file across a directory and return its absolute path.
Args:
filename (Union[str, os.PathLike]): The name of the file to search for.
search_from (Union[str, os.PathLike], optional): The directory from which to start the search. Defaults to ".".
Returns:
str: Absolute path to the found file.
Raises:
FileNotFoundError: If the file is not found.
"""
for root, dirs, files in os.walk(search_from):
for name in files:
if name == filename:
return os.path.abspath(os.path.join(root, name))
    raise FileNotFoundError(f"{filename} not found.")
#+=========================================================================================
def verify(s, target_sequence):
# Count the occurrences of the target sequence
count = s.count(target_sequence)
# Check if the target sequence appears exactly twice
return count == 2
def is_int_between_0_and_10(s):
try:
num = int(s)
return 0 <= num <= 10
except ValueError:
return False
def is_str_a_list_of_ints_0_to_10(s):
try:
# Attempt to parse the string as a Python literal (list, dict, etc.)
parsed = ast.literal_eval(s)
# Check if the parsed object is a list
if not isinstance(parsed, list):
return False
# Check if all elements are integers and between 0 to 10
return all(isinstance(item, int) and 0 <= item <= 10 for item in parsed)
except (ValueError, SyntaxError):
# If parsing fails or any other error occurs
return False
def is_str_valid_score_format_brackets(s):
try:
# Removing brackets and splitting the string by commas
content = s.strip("[]").split(',')
length = len(content)
# Parsing each element and checking the format and range
scores = {}
for item in content:
key, value = item.split(':')
key = key.strip()
value = int(value.strip())
# Check if the key starts with 'score' and the value is in the correct range
if not key.startswith("score") or not 0 <= value <= 10:
return False
scores[key] = value
fetch_words = [f"score{i+1}" for i in range(length)]
        # Check that score1..scoreN are all present
return all(key in scores for key in fetch_words)
except (ValueError, SyntaxError):
# If any parsing error occurs
return False
#+=========================================================================================
def mllm_output_to_dict(input_string, give_up_parsing=False):
"""
Args:
input_string (str): actually the output of the mllm model to be parsed
output_file_name (str): The name of the output file.
"""
# Catch for gpt4v rate_limit_exceeded error
if input_string == "rate_limit_exceeded":
return "rate_limit_exceeded"
    # Find the JSON manually: some MLLMs tend not to output the delimiters
    # but do output the JSON contents, so we locate the JSON ourselves.
start_index = input_string.find('{')
end_index = input_string.rfind('}') + 1
if start_index == -1 or end_index == 0:
# json not found
# some mllm tends to output only a list of scores like [6, 0],
# this time we will just get the scores and ignore the reasoning (other part of the json)
start_index = input_string.find('[')
end_index = input_string.rfind(']') + 1
if is_int_between_0_and_10(input_string): # if output is simply a number
score = int(input_string)
json_content = {'score': score, "reasoning": "System: output is simply a number"}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
else:
raise Exception(f"Failed to find the json content in the string: {input_string=}.")
# Check if we found two delimiters
if start_index != -1 and end_index != -1 and start_index != end_index:
# Extract the JSON string
json_str = input_string[start_index:end_index].strip()
json_str = json_str.replace("\n", "")
# Parse the JSON string into a dictionary
        try:
            new_data = json.loads(json_str)
        except Exception as e:
            print(f"Error: {e}")
            print("Now fixing: ", json_str)
            new_data = json.loads(fix_json(json_str))
        return new_data
else:
raise Exception(f"The required delimiters were not found correctly in the string: {input_string=}.")
def write_entry_to_json_file(input_string, uid, prompt_input, vision_input, output_file_name, give_up_parsing=False):
"""
Args:
input_string (str): actually the output of the mllm model to be parsed
        uid (str): The unique identifier for each item in the test data.
prompt_input (str): The prompt input for the entry. text prompt.
vision_input (str): The vision input for the entry. image links.
output_file_name (str): The name of the output file.
"""
# Catch for gpt4v rate_limit_exceeded error
if input_string == "rate_limit_exceeded":
return "rate_limit_exceeded"
# Define the delimiters
delimiter = '||V^=^V||'
if input_string.count(delimiter) == 2:
if not verify(input_string, delimiter):
print("The required delimiters were not found correctly in the string.")
return False
# Extract the content between the delimiters
start_index = input_string.find(delimiter) + len(delimiter)
end_index = input_string.rfind(delimiter)
else:
        # Find the JSON manually: some MLLMs tend not to output the delimiters
        # but do output the JSON contents, so we locate the JSON ourselves.
start_index = input_string.find('{')
end_index = input_string.rfind('}') + 1
if start_index == -1 or end_index == 0:
# json not found
# some mllm tends to output only a list of scores like [6, 0],
# this time we will just get the scores and ignore the reasoning (other part of the json)
start_index = input_string.find('[')
end_index = input_string.rfind(']') + 1
if give_up_parsing: # if we want to give up parsing
guessed_value = random.randint(0, 10)
print(f"Failed to find the json content in the string. Guess a value : {guessed_value}.")
json_content = {'score': [guessed_value], "reasoning": f"guess_if_cannot_parse | {input_string}"}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
elif re.match(r'^\[\d+, ?\d+\]$', input_string[start_index:end_index]):
scores = json.loads(input_string[start_index:end_index])
json_content = {'score': scores, "reasoning": None}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
elif is_int_between_0_and_10(input_string): # if output is simply a number
scores = [int(input_string)]
json_content = {'score': scores, "reasoning": None}
json_str = json.dumps(json_content)
input_string = json_str
start_index = 0
end_index = len(json_str)
else:
print("Failed to find the json content in the string.")
return False
# Check if we found two delimiters
if start_index != -1 and end_index != -1 and start_index != end_index:
# Extract the JSON string
json_str = input_string[start_index:end_index].strip()
json_str = json_str.replace("\n", "")
try:
# Parse the JSON string into a dictionary
new_data = json.loads(json_str)
# Ensure the directory exists
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
# Initialize or load existing data
if os.path.exists(output_file_name):
with open(output_file_name, 'r') as json_file:
data = json.load(json_file)
else:
data = {}
# If the additional key is already in the data, add or update notes
if uid in data:
data[uid].update(new_data) # Update with new data
if prompt_input: # If there are new notes, update or add them
data[uid]['prompt_input'] = prompt_input
if vision_input: # If there are new notes, update or add them
data[uid]['vision_input'] = vision_input
else:
# If it's a new key, add the entry to the dictionary
data[uid] = new_data
if prompt_input:
data[uid]['prompt_input'] = prompt_input
if vision_input:
data[uid]['vision_input'] = vision_input
# Write the updated data to the file
with open(output_file_name, 'w') as json_file:
json.dump(data, json_file, indent=4)
print(f"Data was successfully updated in {output_file_name}")
return True
except json.JSONDecodeError as e:
print(f"An error occurred while parsing the JSON content: {e}")
return False
else:
print("The required delimiters were not found correctly in the string.")
return False
def check_key_in_json(file_path, key):
try:
with open(file_path, 'r') as json_file:
data = json.load(json_file)
# Check if the key exists at the top level of the JSON structure
if key in data:
return True
else:
return False
except FileNotFoundError:
print(f"The file {file_path} was not found.")
except json.JSONDecodeError as e:
print(f"Error reading {file_path}: {e}")
except Exception as e:
print(f"An error occurred with {file_path}: {e}")
return False
from .prompt_generator import PromptGenerator
from .openai_util import ask_gpt4o
from .json_util import mllm_output_to_dict
import random
import json
import time
class OmniContextScore:
def __init__(self, openai_url: str, openai_key: str) -> None:
self.openai_url = openai_url
self.openai_key = openai_key
self.prompt_generator = PromptGenerator()
def evaluate(self, input_image_paths, instruction, with_scene=False):
results_dict = {}
max_tries = 3
        PF_scores = None
        SC_scores = None
        PF_results = None
        SC_results = None
        for try_idx in range(max_tries):
            try:
                PF_prompt = self.prompt_generator(instruction, task_type="prompt_following")
                SC_prompt = self.prompt_generator(instruction, task_type="subject_consistency", with_scene=with_scene)
                PF_results = ask_gpt4o(input_image_paths, PF_prompt, self.openai_url, self.openai_key)
                SC_results = ask_gpt4o(input_image_paths, SC_prompt, self.openai_url, self.openai_key)
                PF_scores = mllm_output_to_dict(PF_results)
                SC_scores = mllm_output_to_dict(SC_results)
                if PF_scores == "rate_limit_exceeded" or SC_scores == "rate_limit_exceeded":
                    raise Exception("rate_limit_exceeded")
                break  # both scores parsed successfully; stop retrying
            except Exception as e:
                PF_scores = None  # discard partial results so the fallback below applies
                SC_scores = None
                backoff_time = 2 ** try_idx  # Exponential backoff: 1, 2, 4 seconds
                print(f"{e}, {instruction=}, Attempt {try_idx+1} failed, retrying after {backoff_time} seconds...")
                time.sleep(backoff_time)
if PF_scores is None:
guessed_value = random.randint(0, 10)
print(f"Failed to find the json content in the string for {instruction}. Guess a value : {guessed_value=}.", flush=True)
PF_scores = {'score': guessed_value, "reasoning": f"guess_if_cannot_parse | {PF_results}"}
if SC_scores is None:
guessed_value = random.randint(0, 10)
print(f"Failed to find the json content in the string for {instruction}. Guess a value : {guessed_value=}.", flush=True)
SC_scores = {'score': guessed_value, "reasoning": f"guess_if_cannot_parse | {SC_results}"}
results_dict["PF_scores"] = PF_scores
results_dict["SC_scores"] = SC_scores
return results_dict
import requests
import base64
from io import BytesIO
from PIL import Image, ImageOps
from typing import Union, Optional, Tuple, List
import os
def encode_pil_image(pil_image):
# Create an in-memory binary stream
image_stream = BytesIO()
# Save the PIL image to the binary stream in JPEG format (you can change the format if needed)
pil_image.save(image_stream, format='JPEG')
# Get the binary data from the stream and encode it as base64
image_data = image_stream.getvalue()
base64_image = base64.b64encode(image_data).decode('utf-8')
return base64_image
def load_image(image: Union[str, Image.Image], format: str = "RGB", size: Optional[Tuple] = None) -> Image.Image:
"""
Load an image from a given path or URL and convert it to a PIL Image.
Args:
image (Union[str, Image.Image]): The image path, URL, or a PIL Image object to be loaded.
format (str, optional): Desired color format of the resulting image. Defaults to "RGB".
size (Optional[Tuple], optional): Desired size for resizing the image. Defaults to None.
Returns:
Image.Image: A PIL Image in the specified format and size.
Raises:
ValueError: If the provided image format is not recognized.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = ImageOps.exif_transpose(image)
image = image.convert(format)
    if size is not None:
image = image.resize(size, Image.LANCZOS)
return image
def prepare_prompt(image_links: List = [], text_prompt: str = ""):
prompt_content = []
text_dict = {"type": "text", "text": text_prompt}
prompt_content.append(text_dict)
if not isinstance(image_links, list):
image_links = [image_links]
for image_link in image_links:
image = load_image(image_link)
visual_dict = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encode_pil_image(image)}"},
}
prompt_content.append(visual_dict)
return prompt_content
def ask_gpt4o(image_path, prompt, url, api_key):
prompt = prepare_prompt(image_path, prompt)
payload = {
"model": "gpt-4.1",
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1400,
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
try:
        response = requests.post(url, json=payload, headers=headers, timeout=180)  # 3-minute timeout
except Exception as e:
print(f"Error: {e}")
return ""
return extract_response(response)
def extract_response(response):
try:
response = response.json()
out = response["choices"][0]["message"]["content"]
return out
    except Exception:
if response["error"]["code"] == "content_policy_violation":
print("Code is content_policy_violation")
elif response["error"]["code"] in [
"rate_limit_exceeded",
"insufficient_quota",
"insufficient_user_quota",
]:
print(f"Code is {response['error']['code']}", flush=True)
print(response["error"]["message"], flush=True)
return "rate_limit_exceeded"
else:
print("Code is different")
print(response)
print(f"{response['error']['code']=}")
return ""