Unverified Commit 798e1718 authored by Shitao Xiao's avatar Shitao Xiao Committed by GitHub
Browse files

Add OmniGen (#10148)



* OmniGen model.py

* update OmniGenTransformerModel

* omnigen pipeline

* omnigen pipeline

* update omnigen_pipeline

* test case for omnigen

* update omnigenpipeline

* update docs

* update docs

* offload_transformer

* enable_transformer_block_cpu_offload

* update docs

* reformat

* reformat

* reformat

* update docs

* update docs

* make style

* make style

* Update docs/source/en/api/models/omnigen_transformer.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* revert changes to examples/

* update OmniGen2DModel

* make style

* update test cases

* Update docs/source/en/api/pipelines/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* typo

* Update src/diffusers/models/embeddings.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/attention.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* consistent attention processor

* updata

* update

* check_inputs

* make style

* update testpipeline

* update testpipeline

---------
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent ed4b7522
......@@ -89,6 +89,8 @@
title: Kandinsky
- local: using-diffusers/ip_adapter
title: IP-Adapter
- local: using-diffusers/omnigen
title: OmniGen
- local: using-diffusers/pag
title: PAG
- local: using-diffusers/controlnet
......@@ -292,6 +294,8 @@
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/omnigen_transformer
title: OmniGenTransformer2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
......@@ -448,6 +452,8 @@
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/omnigen
title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# OmniGenTransformer2DModel
A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
## OmniGenTransformer2DModel
[[autodoc]] OmniGenTransformer2DModel
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# OmniGen
[OmniGen: Unified Image Generation](https://arxiv.org/pdf/2409.11340) from BAAI, by Shitao Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu.
The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language
generation tasks and revolutionized human-machine interaction.
However, in the realm of image generation, a unified model capable of handling various tasks
within a single framework remains largely unexplored. In
this work, we introduce OmniGen, a new diffusion model
for unified image generation. OmniGen is characterized
by the following features: 1) Unification: OmniGen not
only demonstrates text-to-image generation capabilities but
also inherently supports various downstream tasks, such
as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of
OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion
models, it is more user-friendly and can complete complex
tasks end-to-end through instructions without the need for
extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from
learning in a unified format, OmniGen effectively transfers
knowledge across different tasks, manages unseen tasks and
domains, and exhibits novel capabilities. We also explore
the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism.
This work represents the first attempt at a general-purpose image generation model,
and we will release our resources at https:
//github.com/VectorSpaceLab/OmniGen to foster future advancements.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
## Inference
First, load the pipeline:
```python
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
```
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size.
```py
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image
```
OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image
```
## OmniGenPipeline
[[autodoc]] OmniGenPipeline
- all
- __call__
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# OmniGen
OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is a single model designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation). It has the following features:
- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.
- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text.
For more information, please refer to the [paper](https://arxiv.org/pdf/2409.11340).
This guide will walk you through using OmniGen for various tasks and use cases.
## Load model checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```py
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
```
## Text-to-image
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size.
```py
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image
```
<div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png" alt="generated image"/>
</div>
## Image edit
OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">original image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">edited image</figcaption>
</div>
</div>
OmniGen has some interesting features, such as visual reasoning, as shown in the example below.
```py
prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
```
<div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png" alt="generated image"/>
</div>
## Controllable generation
OmniGen can handle several classic computer vision tasks.
As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="Detect the skeleton of human in this image: <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image1 = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
image1
prompt="Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
image2 = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
image2
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">original image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">detected skeleton</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal2img.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">skeleton to image</figcaption>
</div>
</div>
OmniGen can also directly use relevant information from input images to generate new images.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="Following the pose of this image <img><|image_1|></img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe(
prompt=prompt,
input_images=input_images,
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## ID and object preserving
OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>"
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png")
input_images=[input_image_1, input_image_2]
image = pipe(
prompt=prompt,
input_images=input_images,
height=1024,
width=1024,
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input_image_1</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input_image_2</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/id2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>."
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
input_images=[input_image_1, input_image_2]
image = pipe(
prompt=prompt,
input_images=input_images,
height=1024,
width=1024,
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">person image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">clothe image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/tryon.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
## Optimization when inputting multiple images
For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU).
However, when using input images, the computational cost increases.
Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images.
Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload() `.
In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`.
The memory consumption for different image sizes is shown in the table below:
| Method | Memory Usage |
|---------------------------|--------------|
| max_input_image_size=1024 | 40GB |
| max_input_image_size=512 | 17GB |
| max_input_image_size=256 | 14GB |
import argparse
import os
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
def main(args):
# checkpoint from https://huggingface.co/Shitao/OmniGen-v1
if not os.path.exists(args.origin_ckpt_path):
print("Model not found, downloading...")
cache_folder = os.getenv("HF_HUB_CACHE")
args.origin_ckpt_path = snapshot_download(
repo_id=args.origin_ckpt_path,
cache_dir=cache_folder,
ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
)
print(f"Downloaded model to {args.origin_ckpt_path}")
ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors")
ckpt = load_file(ckpt, device="cpu")
mapping_dict = {
"pos_embed": "patch_embedding.pos_embed",
"x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
"x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
"input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
"input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
"final_layer.linear.weight": "proj_out.weight",
"final_layer.linear.bias": "proj_out.bias",
"time_token.mlp.0.weight": "time_token.linear_1.weight",
"time_token.mlp.0.bias": "time_token.linear_1.bias",
"time_token.mlp.2.weight": "time_token.linear_2.weight",
"time_token.mlp.2.bias": "time_token.linear_2.bias",
"t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
"t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
"t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
"t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
"llm.embed_tokens.weight": "embed_tokens.weight",
}
converted_state_dict = {}
for k, v in ckpt.items():
if k in mapping_dict:
converted_state_dict[mapping_dict[k]] = v
elif "qkv" in k:
to_q, to_k, to_v = v.chunk(3)
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v
elif "o_proj" in k:
converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v
else:
converted_state_dict[k[4:]] = v
transformer = OmniGenTransformer2DModel(
rope_scaling={
"long_factor": [
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0799999237060547,
1.2299998998641968,
1.2299998998641968,
1.2999999523162842,
1.4499999284744263,
1.5999999046325684,
1.6499998569488525,
1.8999998569488525,
2.859999895095825,
3.68999981880188,
5.419999599456787,
5.489999771118164,
5.489999771118164,
9.09000015258789,
11.579999923706055,
15.65999984741211,
15.769999504089355,
15.789999961853027,
18.360000610351562,
21.989999771118164,
23.079999923706055,
30.009998321533203,
32.35000228881836,
32.590003967285156,
35.56000518798828,
39.95000457763672,
53.840003967285156,
56.20000457763672,
57.95000457763672,
59.29000473022461,
59.77000427246094,
59.920005798339844,
61.190006256103516,
61.96000671386719,
62.50000762939453,
63.3700065612793,
63.48000717163086,
63.48000717163086,
63.66000747680664,
63.850006103515625,
64.08000946044922,
64.760009765625,
64.80001068115234,
64.81001281738281,
64.81001281738281,
],
"short_factor": [
1.05,
1.05,
1.05,
1.1,
1.1,
1.1,
1.2500000000000002,
1.2500000000000002,
1.4000000000000004,
1.4500000000000004,
1.5500000000000005,
1.8500000000000008,
1.9000000000000008,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.000000000000001,
2.1000000000000005,
2.1000000000000005,
2.2,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3499999999999996,
2.3999999999999995,
2.3999999999999995,
2.6499999999999986,
2.6999999999999984,
2.8999999999999977,
2.9499999999999975,
3.049999999999997,
3.049999999999997,
3.049999999999997,
],
"type": "su",
},
patch_size=2,
in_channels=4,
pos_embed_max_size=192,
)
transformer.load_state_dict(converted_state_dict, strict=True)
transformer.to(torch.bfloat16)
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
pipeline.save_pretrained(args.dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--origin_ckpt_path",
default="Shitao/OmniGen-v1",
type=str,
required=False,
help="Path to the checkpoint to convert.",
)
parser.add_argument(
"--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline."
)
args = parser.parse_args()
main(args)
......@@ -124,6 +124,7 @@ else:
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaTransformer2DModel",
......@@ -342,6 +343,7 @@ else:
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
......@@ -638,6 +640,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
......@@ -835,6 +838,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
......
......@@ -73,6 +73,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
......@@ -142,6 +143,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
......
......@@ -71,7 +71,7 @@ class AdaLayerNorm(nn.Module):
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
# other if-branch. This branch is specific to CogVideoX for now.
# other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
......
......@@ -22,5 +22,6 @@ if is_torch_available():
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
This diff is collapsed.
......@@ -264,6 +264,7 @@ else:
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
......@@ -602,6 +603,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
from .omnigen import OmniGenPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
......
......@@ -48,9 +48,14 @@ EXAMPLE_DOC_STRING = """
>>> from huggingface_hub import snapshot_download
>>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
>>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
... )
>>> (
... face_helper_1,
... face_helper_2,
... face_clip_model,
... face_main_model,
... eva_transform_mean,
... eva_transform_std,
... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
>>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_omnigen"] = ["OmniGenPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_omnigen import OmniGenPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
This diff is collapsed.
# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
def crop_image(pil_image, max_image_size):
"""
Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and
width are multiples of 16.
"""
while min(*pil_image.size) >= 2 * max_image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
if max(*pil_image.size) > max_image_size:
scale = max_image_size / max(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
if min(*pil_image.size) < 16:
scale = 16 / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y1 = (arr.shape[0] % 16) // 2
crop_y2 = arr.shape[0] % 16 - crop_y1
crop_x1 = (arr.shape[1] % 16) // 2
crop_x2 = arr.shape[1] % 16 - crop_x1
arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
return Image.fromarray(arr)
class OmniGenMultiModalProcessor:
def __init__(self, text_tokenizer, max_image_size: int = 1024):
self.text_tokenizer = text_tokenizer
self.max_image_size = max_image_size
self.image_transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
self.collator = OmniGenCollator()
def reset_max_image_size(self, max_image_size):
self.max_image_size = max_image_size
self.image_transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
def process_image(self, image):
if isinstance(image, str):
image = Image.open(image).convert("RGB")
return self.image_transform(image)
def process_multi_modal_prompt(self, text, input_images):
text = self.add_prefix_instruction(text)
if input_images is None or len(input_images) == 0:
model_inputs = self.text_tokenizer(text)
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
for i in range(1, len(prompt_chunks)):
if prompt_chunks[i][0] == 1:
prompt_chunks[i] = prompt_chunks[i][1:]
image_tags = re.findall(pattern, text)
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids))
assert unique_image_ids == list(
range(1, len(unique_image_ids) + 1)
), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
# total images must be the same as the number of image tags
assert (
len(unique_image_ids) == len(input_images)
), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
input_images = [input_images[x - 1] for x in image_ids]
all_input_ids = []
img_inx = []
for i in range(len(prompt_chunks)):
all_input_ids.extend(prompt_chunks[i])
if i != len(prompt_chunks) - 1:
start_inx = len(all_input_ids)
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
img_inx.append([start_inx, start_inx + size])
all_input_ids.extend([0] * size)
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
def add_prefix_instruction(self, prompt):
user_prompt = "<|user|>\n"
generation_prompt = "Generate an image according to the following instructions\n"
assistant_prompt = "<|assistant|>\n<|diffusion|>"
prompt_suffix = "<|end|>\n"
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
return prompt
def __call__(
self,
instructions: List[str],
input_images: List[List[str]] = None,
height: int = 1024,
width: int = 1024,
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
use_img_cfg: bool = True,
separate_cfg_input: bool = False,
use_input_image_size_as_output: bool = False,
num_images_per_prompt: int = 1,
) -> Dict:
if isinstance(instructions, str):
instructions = [instructions]
input_images = [input_images]
input_data = []
for i in range(len(instructions)):
cur_instruction = instructions[i]
cur_input_images = None if input_images is None else input_images[i]
if cur_input_images is not None and len(cur_input_images) > 0:
cur_input_images = [self.process_image(x) for x in cur_input_images]
else:
cur_input_images = None
assert "<img><|image_1|></img>" not in cur_instruction
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
neg_mllm_input, img_cfg_mllm_input = None, None
neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
if use_img_cfg:
if cur_input_images is not None and len(cur_input_images) >= 1:
img_cfg_prompt = [f"<img><|image_{i + 1}|></img>" for i in range(len(cur_input_images))]
img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
else:
img_cfg_mllm_input = neg_mllm_input
for _ in range(num_images_per_prompt):
if use_input_image_size_as_output:
input_data.append(
(
mllm_input,
neg_mllm_input,
img_cfg_mllm_input,
[mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
)
)
else:
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
return self.collator(input_data)
class OmniGenCollator:
def __init__(self, pad_token_id=2, hidden_size=3072):
self.pad_token_id = pad_token_id
self.hidden_size = hidden_size
def create_position(self, attention_mask, num_tokens_for_output_images):
position_ids = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
for mask in attention_mask:
temp_l = torch.sum(mask)
temp_position = [0] * (text_length - temp_l) + list(
range(temp_l + img_length + 1)
) # we add a time embedding into the sequence, so add one more token
position_ids.append(temp_position)
return torch.LongTensor(position_ids)
def create_mask(self, attention_mask, num_tokens_for_output_images):
"""
OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
each image sequence References: [OmniGen](https://arxiv.org/pdf/2409.11340)
"""
extended_mask = []
padding_images = []
text_length = attention_mask.size(-1)
img_length = max(num_tokens_for_output_images)
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
inx = 0
for mask in attention_mask:
temp_l = torch.sum(mask)
pad_l = text_length - temp_l
temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
image_mask = torch.zeros(size=(temp_l + 1, img_length))
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
if pad_l > 0:
pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
pad_mask = torch.ones(size=(pad_l, seq_len))
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
true_img_length = num_tokens_for_output_images[inx]
pad_img_length = img_length - true_img_length
if pad_img_length > 0:
temp_mask[:, -pad_img_length:] = 0
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
else:
temp_padding_imgs = None
extended_mask.append(temp_mask.unsqueeze(0))
padding_images.append(temp_padding_imgs)
inx += 1
return torch.cat(extended_mask, dim=0), padding_images
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
for b_inx in image_sizes.keys():
for start_inx, end_inx in image_sizes[b_inx]:
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
return attention_mask
def pad_input_ids(self, input_ids, image_sizes):
max_l = max([len(x) for x in input_ids])
padded_ids = []
attention_mask = []
for i in range(len(input_ids)):
temp_ids = input_ids[i]
temp_l = len(temp_ids)
pad_l = max_l - temp_l
if pad_l == 0:
attention_mask.append([1] * max_l)
padded_ids.append(temp_ids)
else:
attention_mask.append([0] * pad_l + [1] * temp_l)
padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
if i in image_sizes:
new_inx = []
for old_inx in image_sizes[i]:
new_inx.append([x + pad_l for x in old_inx])
image_sizes[i] = new_inx
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
def process_mllm_input(self, mllm_inputs, target_img_size):
num_tokens_for_output_images = []
for img_size in target_img_size:
num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
pixel_values, image_sizes = [], {}
b_inx = 0
for x in mllm_inputs:
if x["pixel_values"] is not None:
pixel_values.extend(x["pixel_values"])
for size in x["image_sizes"]:
if b_inx not in image_sizes:
image_sizes[b_inx] = [size]
else:
image_sizes[b_inx].append(size)
b_inx += 1
pixel_values = [x.unsqueeze(0) for x in pixel_values]
input_ids = [x["input_ids"] for x in mllm_inputs]
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
def __call__(self, features):
mllm_inputs = [f[0] for f in features]
cfg_mllm_inputs = [f[1] for f in features]
img_cfg_mllm_input = [f[2] for f in features]
target_img_size = [f[3] for f in features]
if img_cfg_mllm_input[0] is not None:
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
target_img_size = target_img_size + target_img_size + target_img_size
else:
mllm_inputs = mllm_inputs + cfg_mllm_inputs
target_img_size = target_img_size + target_img_size
(
all_padded_input_ids,
all_position_ids,
all_attention_mask,
all_padding_images,
all_pixel_values,
all_image_sizes,
) = self.process_mllm_input(mllm_inputs, target_img_size)
data = {
"input_ids": all_padded_input_ids,
"attention_mask": all_attention_mask,
"position_ids": all_position_ids,
"input_pixel_values": all_pixel_values,
"input_image_sizes": all_image_sizes,
}
return data
......@@ -621,6 +621,21 @@ class MultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class OmniGenTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -1217,6 +1217,21 @@ class MusicLDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class OmniGenPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class PaintByExamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import OmniGenTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = OmniGenTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
sequence_length = 24
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device)
input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device)
input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)]
input_image_sizes = {0: [[0, 0 + height * width // 2 // 2]]}
attn_seq_length = sequence_length + 1 + height * width // 2 // 2
attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device)
position_ids = torch.LongTensor([list(range(attn_seq_length))] * batch_size).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"input_ids": input_ids,
"input_img_latents": input_img_latents,
"input_image_sizes": input_image_sizes,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"hidden_size": 16,
"num_attention_heads": 4,
"num_key_value_heads": 4,
"intermediate_size": 32,
"num_layers": 1,
"pad_token_id": 0,
"vocab_size": 100,
"in_channels": 4,
"time_step_dim": 4,
"rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"OmniGenTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
import gc
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline
params = frozenset(
[
"prompt",
"guidance_scale",
]
)
batch_params = frozenset(
[
"prompt",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
transformer = OmniGenTransformer2DModel(
hidden_size=16,
num_attention_heads=4,
num_key_value_heads=4,
intermediate_size=32,
num_layers=1,
in_channels=4,
time_step_dim=4,
rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
)
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4, 4, 4, 4),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
)
scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 1,
"guidance_scale": 3.0,
"output_type": "np",
"height": 16,
"width": 16,
}
return inputs
def test_inference(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
generated_image = pipe(**inputs).images[0]
self.assertEqual(generated_image.shape, (16, 16, 3))
@slow
@require_torch_gpu
class OmniGenPipelineSlowTests(unittest.TestCase):
pipeline_class = OmniGenPipeline
repo_id = "shitao/OmniGen-v1-diffusers"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
return {
"prompt": "A photo of a cat",
"num_inference_steps": 2,
"guidance_scale": 2.5,
"output_type": "np",
"generator": generator,
}
def test_omnigen_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.1783447, 0.16772744, 0.14339337],
[0.17066911, 0.15521264, 0.13757327],
[0.17072496, 0.15531206, 0.13524258],
[0.16746324, 0.1564025, 0.13794944],
[0.16490817, 0.15258026, 0.13697758],
[0.16971767, 0.15826806, 0.13928896],
[0.16782972, 0.15547255, 0.13783783],
[0.16464645, 0.15281534, 0.13522372],
[0.16535294, 0.15301755, 0.13526791],
[0.16365296, 0.15092957, 0.13443318],
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment