Unverified Commit 57ac6738 authored by Aryan, committed by GitHub

Refactor OmniGen (#10771)



* OmniGen model.py

* update OmniGenTransformerModel

* omnigen pipeline

* omnigen pipeline

* update omnigen_pipeline

* test case for omnigen

* update omnigenpipeline

* update docs

* update docs

* offload_transformer

* enable_transformer_block_cpu_offload

* update docs

* reformat

* reformat

* reformat

* update docs

* update docs

* make style

* make style

* Update docs/source/en/api/models/omnigen_transformer.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* revert changes to examples/

* update OmniGen2DModel

* make style

* update test cases

* Update docs/source/en/api/pipelines/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* typo

* Update src/diffusers/models/embeddings.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/attention.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: hlky <hlky@hlky.ac>

* consistent attention processor

* update

* update

* check_inputs

* make style

* update testpipeline

* update testpipeline

* refactor omnigen

* more updates

* apply review suggestion

---------
Co-authored-by: shitao <2906698981@qq.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
parent 81440fd4

docs/source/en/api/models/omnigen_transformer.md

@@ -14,6 +14,17 @@ specific language governing permissions and limitations under the License.

A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).

The abstract from the paper is:

*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*

```python
import torch
from diffusers import OmniGenTransformer2DModel

transformer = OmniGenTransformer2DModel.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
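The transformer loaded this way can be passed straight to [`OmniGenPipeline`]. The snippet below is a minimal sketch added for illustration; it assumes the usual `transformer=` component override accepted by [`~DiffusionPipeline.from_pretrained`]:

```python
import torch
from diffusers import OmniGenPipeline, OmniGenTransformer2DModel

# Load the transformer on its own (for example in a specific dtype) and hand it to the pipeline.
transformer = OmniGenTransformer2DModel.from_pretrained(
    "Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
```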

## OmniGenTransformer2DModel

[[autodoc]] OmniGenTransformer2DModel

docs/source/en/api/pipelines/omnigen.md

@@ -19,27 +19,7 @@

The abstract from the paper is:

*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*

<Tip>

@@ -49,7 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).

## Inference

First, load the pipeline:

@@ -57,17 +36,15 @@ First, load the pipeline:

```python
import torch
from diffusers import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```

For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size.

```python
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
    prompt=prompt,
@@ -76,14 +53,14 @@ image = pipe(
    guidance_scale=3,
    generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image.save("output.png")
```

OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.

```python
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
@@ -93,14 +70,11 @@ image = pipe(
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image.save("output.png")
```

## OmniGenPipeline

[[autodoc]] OmniGenPipeline
  - all
  - __call__

docs/source/en/using-diffusers/omnigen.md

@@ -19,25 +19,22 @@ For more information, please refer to the [paper](https://arxiv.org/pdf/2409.113

This guide will walk you through using OmniGen for various tasks and use cases.

## Load model checkpoints

Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.

```python
import torch
from diffusers import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
```

## Text-to-image

For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size.

```python
import torch
from diffusers import OmniGenPipeline

@@ -55,8 +52,9 @@ image = pipe(
    guidance_scale=3,
    generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image.save("output.png")
```

<div class="flex justify-center">
  <img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png" alt="generated image"/>
</div>

@@ -67,7 +65,7 @@ OmniGen supports multimodal inputs.

When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.

```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

@@ -86,9 +84,11 @@ image = pipe(
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(222)
).images[0]
image.save("output.png")
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png"/>
@@ -101,7 +101,8 @@ image
</div>

OmniGen has some interesting features, such as visual reasoning, as shown in the example below.

```python
prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe(
@@ -110,20 +111,20 @@ image = pipe(
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(0)
).images[0]
image.save("output.png")
```

<div class="flex justify-center">
  <img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png" alt="generated image"/>
</div>

## Controllable generation

OmniGen can handle several classic computer vision tasks. As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.

```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

@@ -142,8 +143,9 @@ image1 = pipe(
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(333)
).images[0]
image1.save("image1.png")

prompt="Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
@@ -153,8 +155,9 @@ image2 = pipe(
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(333)
).images[0]
image2.save("image2.png")
```

<div class="flex flex-row gap-4">
@@ -174,7 +177,8 @@ image2

OmniGen can also directly use relevant information from input images to generate new images.

```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

@@ -193,9 +197,11 @@ image = pipe(
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(0)
).images[0]
image.save("output.png")
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png"/>
@@ -203,13 +209,12 @@ image
  </div>
</div>

## ID and object preserving

OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.

```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

@@ -231,9 +236,11 @@ image = pipe(
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    generator=torch.Generator(device="cpu").manual_seed(666)
).images[0]
image.save("output.png")
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"/>
@@ -249,7 +256,6 @@ image
  </div>
</div>

```py
import torch
from diffusers import OmniGenPipeline

@@ -261,7 +267,6 @@ pipe = OmniGenPipeline.from_pretrained(
)
pipe.to("cuda")

prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>."
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
@@ -273,8 +278,9 @@ image = pipe(
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    generator=torch.Generator(device="cpu").manual_seed(666)
).images[0]
image.save("output.png")
```

<div class="flex flex-row gap-4">
@@ -292,13 +298,12 @@ image
  </div>
</div>

## Optimization when using multiple images

For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU).
However, when using input images, the computational cost increases.

Here are some guidelines to help you reduce computational costs when using multiple images. The experiments are conducted on an A800 GPU with two input images.

Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload()`.
In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`.
@@ -310,5 +315,3 @@ The memory consumption for different image sizes is shown in the table below:
| max_input_image_size=512 | 17GB |
| max_input_image_size=256 | 14GB |
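
Combining both options, the sketch below (added for illustration; it assumes `max_input_image_size` is accepted by the pipeline call, as the table above suggests) offloads the model and caps the resolution used when encoding input images:

```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
# Offload submodules to the CPU when they are not in use to reduce peak VRAM.
pipe.enable_model_cpu_offload()

prompt = "<img><|image_1|></img> Remove the woman's earrings."
input_images = [load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
    prompt=prompt,
    input_images=input_images,
    guidance_scale=2,
    img_guidance_scale=1.6,
    # Cap the size used when encoding input images to lower memory and compute.
    max_input_image_size=512,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(222),
).images[0]
image.save("output.png")
```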

src/diffusers/models/embeddings.py

@@ -1199,7 +1199,7 @@ def apply_rotary_emb(
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio and OmniGen
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
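
For orientation, the `-1` branch rotates interleaved feature pairs while the `-2` branch rotates half-split pairs, which is the layout OmniGen's attention processor relies on after this change. The following standalone sketch (illustrative values, not part of the diff) contrasts the two layouts:

```python
import torch

# A toy head with D = 8 features: [0, 1, 2, 3, 4, 5, 6, 7].
x = torch.arange(8.0).view(1, 1, 1, 8)

# use_real_unbind_dim=-1: interleaved pairs (d0, d1), (d2, d3), ...
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
rotated_interleaved = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

# use_real_unbind_dim=-2: half-split pairs (d0, d4), (d1, d5), ... as used by Stable Audio and OmniGen.
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
rotated_half = torch.cat([-x_imag, x_real], dim=-1)

print(rotated_interleaved)  # tensor([[[[-1., 0., -3., 2., -5., 4., -7., 6.]]]])
print(rotated_half)         # tensor([[[[-4., -5., -6., -7., 0., 1., 2., 3.]]]])
```

Both layouts apply the same 90 degree rotation to each (real, imaginary) pair; they only differ in which feature indices are treated as a pair.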

src/diffusers/models/transformers/transformer_omnigen.py

@@ -13,17 +13,15 @@
# limitations under the License.

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -34,39 +32,21 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class OmniGenFeedForward(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        up_states = self.gate_up_proj(hidden_states)
        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)
        return self.down_proj(up_states)


class OmniGenPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
@@ -99,7 +79,7 @@ class OmniGenPatchEmbed(nn.Module):
        )
        self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)

    def _cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")
@@ -122,43 +102,34 @@ class OmniGenPatchEmbed(nn.Module):
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
        if is_input_image:
            hidden_states = self.input_image_proj(hidden_states)
        else:
            hidden_states = self.output_image_proj(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return hidden_states

    def forward(
        self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
    ) -> torch.Tensor:
        if isinstance(hidden_states, list):
            if padding_latent is None:
                padding_latent = [None] * len(hidden_states)
            patched_latents = []
            for sub_latent, padding in zip(hidden_states, padding_latent):
                height, width = sub_latent.shape[-2:]
                sub_latent = self._patch_embeddings(sub_latent, is_input_image)
                pos_embed = self._cropped_pos_embed(height, width)
                sub_latent = sub_latent + pos_embed
                if padding is not None:
                    sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
                patched_latents.append(sub_latent)
        else:
            height, width = hidden_states.shape[-2:]
            pos_embed = self._cropped_pos_embed(height, width)
            hidden_states = self._patch_embeddings(hidden_states, is_input_image)
            patched_latents = hidden_states + pos_embed

        return patched_latents
@@ -180,15 +151,16 @@ class OmniGenSuScaledRotaryEmbedding(nn.Module):
        self.long_factor = rope_scaling["long_factor"]
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, hidden_states, position_ids):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device)

        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim
        )
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
@@ -196,11 +168,11 @@ class OmniGenSuScaledRotaryEmbedding(nn.Module):
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = hidden_states.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)[0]

        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
@@ -210,44 +182,7 @@ class OmniGenSuScaledRotaryEmbedding(nn.Module):
        cos = emb.cos() * scaling_factor
        sin = emb.sin() * scaling_factor
        return cos, sin


class OmniGenAttnProcessor2_0:
@@ -278,7 +213,6 @@ class OmniGenAttnProcessor2_0:
        bsz, q_len, query_dim = query.size()
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads

        # Get key-value heads
        kv_heads = inner_dim // head_dim
@@ -289,32 +223,19 @@ class OmniGenAttnProcessor2_0:
        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
            key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).type_as(query)
        hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim)
        hidden_states = attn.to_out[0](hidden_states)
        return hidden_states


class OmniGenBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
@@ -341,78 +262,77 @@ class OmniGenBlock(nn.Module):
        self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
    ) -> torch.Tensor:
        # 1. Attention
        norm_hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Feed Forward
        norm_hidden_states = self.post_attention_layernorm(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        hidden_states = hidden_states + ff_output

        return hidden_states


class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, default to `32`):
            The number of layers of transformer blocks to use.
        pad_token_id (`int`, default to `32000`):
            The id of the padding token.
        vocab_size (`int`, default to `32064`):
            The size of the vocabulary of the embedding vocabulary.
        rope_base (`int`, default to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`, optional):
            The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, default to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, default to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, default to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, default to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, default to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
@@ -423,8 +343,6 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Dict = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
@@ -434,8 +352,6 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
@@ -448,11 +364,8 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
@@ -462,126 +375,34 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        self.layers = nn.ModuleList(
            [
                OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
                for _ in range(num_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict
    ) -> Optional[torch.Tensor]:
        if input_ids is None:
            return None

        input_img_latents = [x.to(self.dtype) for x in input_img_latents]
        condition_tokens = self.embed_tokens(input_ids)
        input_img_inx = 0
        input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
        for b_inx in input_image_sizes.keys():
            for start_inx, end_inx in input_image_sizes[b_inx]:
                # replace the placeholder in text tokens with the image embedding.
                condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
                    condition_tokens.dtype
                )
                input_img_inx += 1

        return condition_tokens
     def forward(
@@ -593,106 +414,55 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         input_image_sizes: Dict[int, List[int]],
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
-        """
-        The [`OmniGenTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
-                Input `hidden_states`.
-            timestep (`torch.FloatTensor`):
-                Used to indicate denoising step.
-            input_ids (`torch.LongTensor`):
-                token ids
-            input_img_latents (`torch.Tensor`):
-                encoded image latents by VAE
-            input_image_sizes (`dict`):
-                the indices of the input_img_latents in the input_ids
-            attention_mask (`torch.Tensor`):
-                mask for self-attention
-            position_ids (`torch.LongTensor`):
-                id to represent position
-            past_key_values (`transformers.cache_utils.Cache`):
-                previous key and value states
-            offload_transformer_block (`bool`, *optional*, defaults to `True`):
-                offload transformer block to cpu
-            attention_kwargs: (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain tuple.
-
-        Returns:
-            If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a `tuple` where the first
-            element is the sample tensor.
-        """
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-        height, width = hidden_states.size()[-2:]
+        batch_size, num_channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        # 1. Patch & Timestep & Conditional Embedding
         hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
         num_tokens_for_output_image = hidden_states.size(1)
-        time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1)
+        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
+        time_token = self.time_token(timestep_proj).unsqueeze(1)
+        temb = self.t_embedder(timestep_proj)
-        condition_tokens = self.get_multimodal_embeddings(
-            input_ids=input_ids,
-            input_img_latents=input_img_latents,
-            input_image_sizes=input_image_sizes,
-        )
+        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
         if condition_tokens is not None:
-            inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
+            hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
         else:
-            inputs_embeds = torch.cat([time_token, hidden_states], dim=1)
+            hidden_states = torch.cat([time_token, hidden_states], dim=1)
-        batch_size, seq_length = inputs_embeds.shape[:2]
+        seq_length = hidden_states.size(1)
         position_ids = position_ids.view(-1, seq_length).long()
+        # 2. Attention mask preprocessing
         if attention_mask is not None and attention_mask.dim() == 3:
-            dtype = inputs_embeds.dtype
+            dtype = hidden_states.dtype
             min_dtype = torch.finfo(dtype).min
             attention_mask = (1 - attention_mask) * min_dtype
-            attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
-        else:
-            raise Exception("attention_mask parameter was unavailable or invalid")
-        hidden_states = inputs_embeds
-        image_rotary_emb = self.rotary_emb(hidden_states, position_ids)
+            attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)
+        # 3. Rotary position embedding
+        image_rotary_emb = self.rope(hidden_states, position_ids)
-        for decoder_layer in self.layers:
+        # 4. Transformer blocks
+        for block in self.layers:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
-                    decoder_layer, hidden_states, attention_mask, image_rotary_emb
+                    block, hidden_states, attention_mask, image_rotary_emb
                 )
             else:
-                hidden_states = decoder_layer(
-                    hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb
-                )
+                hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)
+        # 5. Output norm & projection
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states[:, -num_tokens_for_output_image:]
-        timestep_proj = self.time_proj(timestep)
-        temb = self.t_embedder(timestep_proj.type_as(hidden_states))
         hidden_states = self.norm_out(hidden_states, temb=temb)
         hidden_states = self.proj_out(hidden_states)
-        output = self.unpatchify(hidden_states, height, width)
+        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
+        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
         if not return_dict:
             return (output,)
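# Illustrative sketch with toy sizes (not part of the model): the reshape/permute/flatten above maps
# (batch, H/p * W/p, p * p * C) output tokens back to an image-shaped tensor of shape (batch, C, H, W).
import torch

B, C, H, W, p = 2, 4, 8, 8, 2
tokens = torch.randn(B, (H // p) * (W // p), p * p * C)
x = tokens.reshape(B, H // p, W // p, p, p, -1)
image = x.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
assert image.shape == (B, C, H, W)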
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 import torch
@@ -23,11 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import OmniGenTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .processor_omnigen import OmniGenMultiModalProcessor
@@ -48,11 +44,12 @@ EXAMPLE_DOC_STRING = """
         >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
         >>> prompt = "A cat holding a sign that says hello world"
         >>> # Depending on the variant being used, the pipeline call will slightly vary.
         >>> # Refer to the pipeline documentation for more details.
         >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
-        >>> image.save("t2i.png")
+        >>> image.save("output.png")
         ```
 """
@@ -200,7 +197,6 @@ class OmniGenPipeline(
         width,
         use_input_image_size_as_output,
         callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
     ):
         if input_images is not None:
             if len(input_images) != len(prompt):
@@ -324,10 +320,8 @@ class OmniGenPipeline(
         latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 120000,
     ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -376,10 +370,6 @@ class OmniGenPipeline( ...@@ -376,10 +370,6 @@ class OmniGenPipeline(
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*): callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
...@@ -389,7 +379,6 @@ class OmniGenPipeline( ...@@ -389,7 +379,6 @@ class OmniGenPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples: Examples:
...@@ -414,11 +403,9 @@ class OmniGenPipeline( ...@@ -414,11 +403,9 @@ class OmniGenPipeline(
width, width,
use_input_image_size_as_output, use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False self._interrupt = False
# 2. Define call parameters # 2. Define call parameters
...@@ -451,7 +438,8 @@ class OmniGenPipeline( ...@@ -451,7 +438,8 @@ class OmniGenPipeline(
) )
self._num_timesteps = len(timesteps) self._num_timesteps = len(timesteps)
# 6. Prepare latents. # 6. Prepare latents
transformer_dtype = self.transformer.dtype
if use_input_image_size_as_output: if use_input_image_size_as_output:
height, width = processed_data["input_pixel_values"][0].shape[-2:] height, width = processed_data["input_pixel_values"][0].shape[-2:]
latent_channels = self.transformer.config.in_channels latent_channels = self.transformer.config.in_channels
...@@ -460,7 +448,7 @@ class OmniGenPipeline( ...@@ -460,7 +448,7 @@ class OmniGenPipeline(
latent_channels, latent_channels,
height, height,
width, width,
self.transformer.dtype, torch.float32,
device, device,
generator, generator,
latents, latents,
...@@ -471,6 +459,7 @@ class OmniGenPipeline( ...@@ -471,6 +459,7 @@ class OmniGenPipeline(
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (num_cfg + 1)) latent_model_input = torch.cat([latents] * (num_cfg + 1))
latent_model_input = latent_model_input.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]) timestep = t.expand(latent_model_input.shape[0])
...@@ -483,7 +472,6 @@ class OmniGenPipeline( ...@@ -483,7 +472,6 @@ class OmniGenPipeline(
input_image_sizes=processed_data["input_image_sizes"], input_image_sizes=processed_data["input_image_sizes"],
attention_mask=processed_data["attention_mask"], attention_mask=processed_data["attention_mask"],
position_ids=processed_data["position_ids"], position_ids=processed_data["position_ids"],
attention_kwargs=attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -495,7 +483,6 @@ class OmniGenPipeline( ...@@ -495,7 +483,6 @@ class OmniGenPipeline(
noise_pred = uncond + guidance_scale * (cond - uncond) noise_pred = uncond + guidance_scale * (cond - uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None: if callback_on_step_end is not None:
...@@ -506,11 +493,6 @@ class OmniGenPipeline( ...@@ -506,11 +493,6 @@ class OmniGenPipeline(
latents = callback_outputs.pop("latents", latents) latents = callback_outputs.pop("latents", latents)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
progress_bar.update() progress_bar.update()
if not output_type == "latent": if not output_type == "latent":
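# Illustrative sketch (not part of the diff): the classifier-free guidance combination used in the
# denoising loop above, assuming a simple 2-way conditional/unconditional split of the batched output.
import torch

noise_pred_batched = torch.randn(2, 4, 32, 32)      # conditional + unconditional stacked on the batch dim
cond, uncond = noise_pred_batched.chunk(2, dim=0)   # the chunk order here is an assumption
guidance_scale = 2.5
noise_pred = uncond + guidance_scale * (cond - uncond)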
...
@@ -18,17 +18,10 @@ from ..test_pipelines_common import PipelineTesterMixin
 class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = OmniGenPipeline
-    params = frozenset(
-        [
-            "prompt",
-            "guidance_scale",
-        ]
-    )
-    batch_params = frozenset(
-        [
-            "prompt",
-        ]
-    )
+    params = frozenset(["prompt", "guidance_scale"])
+    batch_params = frozenset(["prompt"])
+    test_layerwise_casting = True

     def get_dummy_components(self):
         torch.manual_seed(0)
...