Unverified Commit 57ac6738 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Refactor OmniGen (#10771)



* OmniGen model.py

* update OmniGenTransformerModel

* omnigen pipeline

* omnigen pipeline

* update omnigen_pipeline

* test case for omnigen

* update omnigenpipeline

* update docs

* update docs

* offload_transformer

* enable_transformer_block_cpu_offload

* update docs

* reformat

* reformat

* reformat

* update docs

* update docs

* make style

* make style

* Update docs/source/en/api/models/omnigen_transformer.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* revert changes to examples/

* update OmniGen2DModel

* make style

* update test cases

* Update docs/source/en/api/pipelines/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* typo

* Update src/diffusers/models/embeddings.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/attention.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* consistent attention processor

* updata

* update

* check_inputs

* make style

* update testpipeline

* update testpipeline

* refactor omnigen

* more updates

* apply review suggestion

---------
Co-authored-by: default avatarshitao <2906698981@qq.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 81440fd4
...@@ -14,6 +14,17 @@ specific language governing permissions and limitations under the License. ...@@ -14,6 +14,17 @@ specific language governing permissions and limitations under the License.
A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/). A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
```python
import torch
from diffusers import OmniGenTransformer2DModel
transformer = OmniGenTransformer2DModel.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## OmniGenTransformer2DModel ## OmniGenTransformer2DModel
[[autodoc]] OmniGenTransformer2DModel [[autodoc]] OmniGenTransformer2DModel
...@@ -19,27 +19,7 @@ ...@@ -19,27 +19,7 @@
The abstract from the paper is: The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language *The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
generation tasks and revolutionized human-machine interaction.
However, in the realm of image generation, a unified model capable of handling various tasks
within a single framework remains largely unexplored. In
this work, we introduce OmniGen, a new diffusion model
for unified image generation. OmniGen is characterized
by the following features: 1) Unification: OmniGen not
only demonstrates text-to-image generation capabilities but
also inherently supports various downstream tasks, such
as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of
OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion
models, it is more user-friendly and can complete complex
tasks end-to-end through instructions without the need for
extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from
learning in a unified format, OmniGen effectively transfers
knowledge across different tasks, manages unseen tasks and
domains, and exhibits novel capabilities. We also explore
the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism.
This work represents the first attempt at a general-purpose image generation model,
and we will release our resources at https:
//github.com/VectorSpaceLab/OmniGen to foster future advancements.*
<Tip> <Tip>
...@@ -49,7 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m ...@@ -49,7 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1). This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
## Inference ## Inference
First, load the pipeline: First, load the pipeline:
...@@ -57,17 +36,15 @@ First, load the pipeline: ...@@ -57,17 +36,15 @@ First, load the pipeline:
```python ```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers", pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
torch_dtype=torch.bfloat16
)
pipe.to("cuda") pipe.to("cuda")
``` ```
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size. You can try setting the `height` and `width` parameters to generate images with different size.
```py ```python
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD." prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe( image = pipe(
prompt=prompt, prompt=prompt,
...@@ -76,14 +53,14 @@ image = pipe( ...@@ -76,14 +53,14 @@ image = pipe(
guidance_scale=3, guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111), generator=torch.Generator(device="cpu").manual_seed(111),
).images[0] ).images[0]
image image.save("output.png")
``` ```
OmniGen supports multimodal inputs. OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image. When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image. It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py ```python
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola." prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")] input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe( image = pipe(
...@@ -93,14 +70,11 @@ image = pipe( ...@@ -93,14 +70,11 @@ image = pipe(
img_guidance_scale=1.6, img_guidance_scale=1.6,
use_input_image_size_as_output=True, use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0] generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image image.save("output.png")
``` ```
## OmniGenPipeline ## OmniGenPipeline
[[autodoc]] OmniGenPipeline [[autodoc]] OmniGenPipeline
- all - all
- __call__ - __call__
...@@ -19,25 +19,22 @@ For more information, please refer to the [paper](https://arxiv.org/pdf/2409.113 ...@@ -19,25 +19,22 @@ For more information, please refer to the [paper](https://arxiv.org/pdf/2409.113
This guide will walk you through using OmniGen for various tasks and use cases. This guide will walk you through using OmniGen for various tasks and use cases.
## Load model checkpoints ## Load model checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```py ```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
```
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
```
## Text-to-image ## Text-to-image
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size. You can try setting the `height` and `width` parameters to generate images with different size.
```py ```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
...@@ -55,8 +52,9 @@ image = pipe( ...@@ -55,8 +52,9 @@ image = pipe(
guidance_scale=3, guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111), generator=torch.Generator(device="cpu").manual_seed(111),
).images[0] ).images[0]
image image.save("output.png")
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png" alt="generated image"/> <img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png" alt="generated image"/>
</div> </div>
...@@ -67,7 +65,7 @@ OmniGen supports multimodal inputs. ...@@ -67,7 +65,7 @@ OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image. When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image. It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py ```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
from diffusers.utils import load_image from diffusers.utils import load_image
...@@ -86,9 +84,11 @@ image = pipe( ...@@ -86,9 +84,11 @@ image = pipe(
guidance_scale=2, guidance_scale=2,
img_guidance_scale=1.6, img_guidance_scale=1.6,
use_input_image_size_as_output=True, use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0] generator=torch.Generator(device="cpu").manual_seed(222)
image ).images[0]
image.save("output.png")
``` ```
<div class="flex flex-row gap-4"> <div class="flex flex-row gap-4">
<div class="flex-1"> <div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png"/> <img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png"/>
...@@ -101,7 +101,8 @@ image ...@@ -101,7 +101,8 @@ image
</div> </div>
OmniGen has some interesting features, such as visual reasoning, as shown in the example below. OmniGen has some interesting features, such as visual reasoning, as shown in the example below.
```py
```python
prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>" prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")] input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe( image = pipe(
...@@ -110,20 +111,20 @@ image = pipe( ...@@ -110,20 +111,20 @@ image = pipe(
guidance_scale=2, guidance_scale=2,
img_guidance_scale=1.6, img_guidance_scale=1.6,
use_input_image_size_as_output=True, use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0] generator=torch.Generator(device="cpu").manual_seed(0)
image ).images[0]
image.save("output.png")
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png" alt="generated image"/> <img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png" alt="generated image"/>
</div> </div>
## Controllable generation ## Controllable generation
OmniGen can handle several classic computer vision tasks. OmniGen can handle several classic computer vision tasks. As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
```py ```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
from diffusers.utils import load_image from diffusers.utils import load_image
...@@ -142,8 +143,9 @@ image1 = pipe( ...@@ -142,8 +143,9 @@ image1 = pipe(
guidance_scale=2, guidance_scale=2,
img_guidance_scale=1.6, img_guidance_scale=1.6,
use_input_image_size_as_output=True, use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0] generator=torch.Generator(device="cpu").manual_seed(333)
image1 ).images[0]
image1.save("image1.png")
prompt="Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him." prompt="Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")] input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
...@@ -153,8 +155,9 @@ image2 = pipe( ...@@ -153,8 +155,9 @@ image2 = pipe(
guidance_scale=2, guidance_scale=2,
img_guidance_scale=1.6, img_guidance_scale=1.6,
use_input_image_size_as_output=True, use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0] generator=torch.Generator(device="cpu").manual_seed(333)
image2 ).images[0]
image2.save("image2.png")
``` ```
<div class="flex flex-row gap-4"> <div class="flex flex-row gap-4">
...@@ -174,7 +177,8 @@ image2 ...@@ -174,7 +177,8 @@ image2
OmniGen can also directly use relevant information from input images to generate new images. OmniGen can also directly use relevant information from input images to generate new images.
```py
```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
from diffusers.utils import load_image from diffusers.utils import load_image
...@@ -193,9 +197,11 @@ image = pipe( ...@@ -193,9 +197,11 @@ image = pipe(
guidance_scale=2, guidance_scale=2,
img_guidance_scale=1.6, img_guidance_scale=1.6,
use_input_image_size_as_output=True, use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0] generator=torch.Generator(device="cpu").manual_seed(0)
image ).images[0]
image.save("output.png")
``` ```
<div class="flex flex-row gap-4"> <div class="flex flex-row gap-4">
<div class="flex-1"> <div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png"/> <img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png"/>
...@@ -203,13 +209,12 @@ image ...@@ -203,13 +209,12 @@ image
</div> </div>
</div> </div>
## ID and object preserving ## ID and object preserving
OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously. OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions. Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.
```py ```python
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
from diffusers.utils import load_image from diffusers.utils import load_image
...@@ -231,9 +236,11 @@ image = pipe( ...@@ -231,9 +236,11 @@ image = pipe(
width=1024, width=1024,
guidance_scale=2.5, guidance_scale=2.5,
img_guidance_scale=1.6, img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0] generator=torch.Generator(device="cpu").manual_seed(666)
image ).images[0]
image.save("output.png")
``` ```
<div class="flex flex-row gap-4"> <div class="flex flex-row gap-4">
<div class="flex-1"> <div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"/> <img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"/>
...@@ -249,7 +256,6 @@ image ...@@ -249,7 +256,6 @@ image
</div> </div>
</div> </div>
```py ```py
import torch import torch
from diffusers import OmniGenPipeline from diffusers import OmniGenPipeline
...@@ -261,7 +267,6 @@ pipe = OmniGenPipeline.from_pretrained( ...@@ -261,7 +267,6 @@ pipe = OmniGenPipeline.from_pretrained(
) )
pipe.to("cuda") pipe.to("cuda")
prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>." prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>."
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg") input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg") input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
...@@ -273,8 +278,9 @@ image = pipe( ...@@ -273,8 +278,9 @@ image = pipe(
width=1024, width=1024,
guidance_scale=2.5, guidance_scale=2.5,
img_guidance_scale=1.6, img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0] generator=torch.Generator(device="cpu").manual_seed(666)
image ).images[0]
image.save("output.png")
``` ```
<div class="flex flex-row gap-4"> <div class="flex flex-row gap-4">
...@@ -292,13 +298,12 @@ image ...@@ -292,13 +298,12 @@ image
</div> </div>
</div> </div>
## Optimization when using multiple images
## Optimization when inputting multiple images
For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU). For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU).
However, when using input images, the computational cost increases. However, when using input images, the computational cost increases.
Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images. Here are some guidelines to help you reduce computational costs when using multiple images. The experiments are conducted on an A800 GPU with two input images.
Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload() `. Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload() `.
In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`. In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`.
...@@ -310,5 +315,3 @@ The memory consumption for different image sizes is shown in the table below: ...@@ -310,5 +315,3 @@ The memory consumption for different image sizes is shown in the table below:
| max_input_image_size=512 | 17GB | | max_input_image_size=512 | 17GB |
| max_input_image_size=256 | 14GB | | max_input_image_size=256 | 14GB |
...@@ -1199,7 +1199,7 @@ def apply_rotary_emb( ...@@ -1199,7 +1199,7 @@ def apply_rotary_emb(
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2: elif use_real_unbind_dim == -2:
# Used for Stable Audio # Used for Stable Audio and OmniGen
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1) x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else: else:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -23,11 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor ...@@ -23,11 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models.autoencoders import AutoencoderKL from ...models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGenTransformer2DModel from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import ( from ...utils import is_torch_xla_available, logging, replace_example_docstring
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .processor_omnigen import OmniGenMultiModalProcessor from .processor_omnigen import OmniGenMultiModalProcessor
...@@ -48,11 +44,12 @@ EXAMPLE_DOC_STRING = """ ...@@ -48,11 +44,12 @@ EXAMPLE_DOC_STRING = """
>>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16) >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda") >>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world" >>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details. >>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
>>> image.save("t2i.png") >>> image.save("output.png")
``` ```
""" """
...@@ -200,7 +197,6 @@ class OmniGenPipeline( ...@@ -200,7 +197,6 @@ class OmniGenPipeline(
width, width,
use_input_image_size_as_output, use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
): ):
if input_images is not None: if input_images is not None:
if len(input_images) != len(prompt): if len(input_images) != len(prompt):
...@@ -324,10 +320,8 @@ class OmniGenPipeline( ...@@ -324,10 +320,8 @@ class OmniGenPipeline(
latents: Optional[torch.Tensor] = None, latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 120000,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -376,10 +370,6 @@ class OmniGenPipeline( ...@@ -376,10 +370,6 @@ class OmniGenPipeline(
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*): callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
...@@ -389,7 +379,6 @@ class OmniGenPipeline( ...@@ -389,7 +379,6 @@ class OmniGenPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples: Examples:
...@@ -414,11 +403,9 @@ class OmniGenPipeline( ...@@ -414,11 +403,9 @@ class OmniGenPipeline(
width, width,
use_input_image_size_as_output, use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
...@@ -451,7 +438,8 @@ class OmniGenPipeline( ...@@ -451,7 +438,8 @@ class OmniGenPipeline(
) )
self._num_timesteps = len(timesteps) self._num_timesteps = len(timesteps)
# 6. Prepare latents. # 6. Prepare latents
transformer_dtype = self.transformer.dtype
if use_input_image_size_as_output: if use_input_image_size_as_output:
height, width = processed_data["input_pixel_values"][0].shape[-2:] height, width = processed_data["input_pixel_values"][0].shape[-2:]
latent_channels = self.transformer.config.in_channels latent_channels = self.transformer.config.in_channels
...@@ -460,7 +448,7 @@ class OmniGenPipeline( ...@@ -460,7 +448,7 @@ class OmniGenPipeline(
latent_channels, latent_channels,
height, height,
width, width,
self.transformer.dtype, torch.float32,
device, device,
generator, generator,
latents, latents,
...@@ -471,6 +459,7 @@ class OmniGenPipeline( ...@@ -471,6 +459,7 @@ class OmniGenPipeline(
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (num_cfg + 1)) latent_model_input = torch.cat([latents] * (num_cfg + 1))
latent_model_input = latent_model_input.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]) timestep = t.expand(latent_model_input.shape[0])
...@@ -483,7 +472,6 @@ class OmniGenPipeline( ...@@ -483,7 +472,6 @@ class OmniGenPipeline(
input_image_sizes=processed_data["input_image_sizes"], input_image_sizes=processed_data["input_image_sizes"],
attention_mask=processed_data["attention_mask"], attention_mask=processed_data["attention_mask"],
position_ids=processed_data["position_ids"], position_ids=processed_data["position_ids"],
attention_kwargs=attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -495,7 +483,6 @@ class OmniGenPipeline( ...@@ -495,7 +483,6 @@ class OmniGenPipeline(
noise_pred = uncond + guidance_scale * (cond - uncond) noise_pred = uncond + guidance_scale * (cond - uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None: if callback_on_step_end is not None:
...@@ -506,11 +493,6 @@ class OmniGenPipeline( ...@@ -506,11 +493,6 @@ class OmniGenPipeline(
latents = callback_outputs.pop("latents", latents) latents = callback_outputs.pop("latents", latents)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
progress_bar.update() progress_bar.update()
if not output_type == "latent": if not output_type == "latent":
......
...@@ -18,17 +18,10 @@ from ..test_pipelines_common import PipelineTesterMixin ...@@ -18,17 +18,10 @@ from ..test_pipelines_common import PipelineTesterMixin
class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin): class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline pipeline_class = OmniGenPipeline
params = frozenset( params = frozenset(["prompt", "guidance_scale"])
[ batch_params = frozenset(["prompt"])
"prompt",
"guidance_scale", test_layerwise_casting = True
]
)
batch_params = frozenset(
[
"prompt",
]
)
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment