Unverified Commit 57ac6738 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Refactor OmniGen (#10771)



* OmniGen model.py

* update OmniGenTransformerModel

* omnigen pipeline

* omnigen pipeline

* update omnigen_pipeline

* test case for omnigen

* update omnigenpipeline

* update docs

* update docs

* offload_transformer

* enable_transformer_block_cpu_offload

* update docs

* reformat

* reformat

* reformat

* update docs

* update docs

* make style

* make style

* Update docs/source/en/api/models/omnigen_transformer.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* revert changes to examples/

* update OmniGen2DModel

* make style

* update test cases

* Update docs/source/en/api/pipelines/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/using-diffusers/omnigen.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* typo

* Update src/diffusers/models/embeddings.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/attention.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/models/transformers/transformer_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update tests/pipelines/omnigen/test_pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* consistent attention processor

* updata

* update

* check_inputs

* make style

* update testpipeline

* update testpipeline

* refactor omnigen

* more updates

* apply review suggestion

---------
Co-authored-by: default avatarshitao <2906698981@qq.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 81440fd4
......@@ -14,6 +14,17 @@ specific language governing permissions and limitations under the License.
A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
```python
import torch
from diffusers import OmniGenTransformer2DModel
transformer = OmniGenTransformer2DModel.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## OmniGenTransformer2DModel
[[autodoc]] OmniGenTransformer2DModel
......@@ -19,27 +19,7 @@
The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language
generation tasks and revolutionized human-machine interaction.
However, in the realm of image generation, a unified model capable of handling various tasks
within a single framework remains largely unexplored. In
this work, we introduce OmniGen, a new diffusion model
for unified image generation. OmniGen is characterized
by the following features: 1) Unification: OmniGen not
only demonstrates text-to-image generation capabilities but
also inherently supports various downstream tasks, such
as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of
OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion
models, it is more user-friendly and can complete complex
tasks end-to-end through instructions without the need for
extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from
learning in a unified format, OmniGen effectively transfers
knowledge across different tasks, manages unseen tasks and
domains, and exhibits novel capabilities. We also explore
the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism.
This work represents the first attempt at a general-purpose image generation model,
and we will release our resources at https:
//github.com/VectorSpaceLab/OmniGen to foster future advancements.*
*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
<Tip>
......@@ -49,7 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
## Inference
First, load the pipeline:
......@@ -57,17 +36,15 @@ First, load the pipeline:
```python
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size.
```py
```python
prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
prompt=prompt,
......@@ -76,14 +53,14 @@ image = pipe(
guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image
image.save("output.png")
```
OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py
```python
prompt="<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
image = pipe(
......@@ -93,14 +70,11 @@ image = pipe(
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image
image.save("output.png")
```
## OmniGenPipeline
[[autodoc]] OmniGenPipeline
- all
- __call__
......@@ -19,25 +19,22 @@ For more information, please refer to the [paper](https://arxiv.org/pdf/2409.113
This guide will walk you through using OmniGen for various tasks and use cases.
## Load model checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```py
```python
import torch
from diffusers import OmniGenPipeline
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
```
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
```
## Text-to-image
For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
You can try setting the `height` and `width` parameters to generate images with different size.
```py
```python
import torch
from diffusers import OmniGenPipeline
......@@ -55,8 +52,9 @@ image = pipe(
guidance_scale=3,
generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image
image.save("output.png")
```
<div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png" alt="generated image"/>
</div>
......@@ -67,7 +65,7 @@ OmniGen supports multimodal inputs.
When the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image.
It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
```py
```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
......@@ -86,9 +84,11 @@ image = pipe(
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
image
generator=torch.Generator(device="cpu").manual_seed(222)
).images[0]
image.save("output.png")
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png"/>
......@@ -101,7 +101,8 @@ image
</div>
OmniGen has some interesting features, such as visual reasoning, as shown in the example below.
```py
```python
prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>"
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
image = pipe(
......@@ -110,20 +111,20 @@ image = pipe(
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
generator=torch.Generator(device="cpu").manual_seed(0)
).images[0]
image.save("output.png")
```
<div class="flex justify-center">
<img src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png" alt="generated image"/>
</div>
## Controllable generation
OmniGen can handle several classic computer vision tasks.
As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
OmniGen can handle several classic computer vision tasks. As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
```py
```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
......@@ -142,8 +143,9 @@ image1 = pipe(
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
image1
generator=torch.Generator(device="cpu").manual_seed(333)
).images[0]
image1.save("image1.png")
prompt="Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
......@@ -153,8 +155,9 @@ image2 = pipe(
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
image2
generator=torch.Generator(device="cpu").manual_seed(333)
).images[0]
image2.save("image2.png")
```
<div class="flex flex-row gap-4">
......@@ -174,7 +177,8 @@ image2
OmniGen can also directly use relevant information from input images to generate new images.
```py
```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
......@@ -193,9 +197,11 @@ image = pipe(
guidance_scale=2,
img_guidance_scale=1.6,
use_input_image_size_as_output=True,
generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
generator=torch.Generator(device="cpu").manual_seed(0)
).images[0]
image.save("output.png")
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png"/>
......@@ -203,13 +209,12 @@ image
</div>
</div>
## ID and object preserving
OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.
```py
```python
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image
......@@ -231,9 +236,11 @@ image = pipe(
width=1024,
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
image
generator=torch.Generator(device="cpu").manual_seed(666)
).images[0]
image.save("output.png")
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"/>
......@@ -249,7 +256,6 @@ image
</div>
</div>
```py
import torch
from diffusers import OmniGenPipeline
......@@ -261,7 +267,6 @@ pipe = OmniGenPipeline.from_pretrained(
)
pipe.to("cuda")
prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>."
input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
......@@ -273,8 +278,9 @@ image = pipe(
width=1024,
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
image
generator=torch.Generator(device="cpu").manual_seed(666)
).images[0]
image.save("output.png")
```
<div class="flex flex-row gap-4">
......@@ -292,13 +298,12 @@ image
</div>
</div>
## Optimization when inputting multiple images
## Optimization when using multiple images
For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU).
However, when using input images, the computational cost increases.
Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images.
Here are some guidelines to help you reduce computational costs when using multiple images. The experiments are conducted on an A800 GPU with two input images.
Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload() `.
In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`.
......@@ -310,5 +315,3 @@ The memory consumption for different image sizes is shown in the table below:
| max_input_image_size=512 | 17GB |
| max_input_image_size=256 | 14GB |
......@@ -1199,7 +1199,7 @@ def apply_rotary_emb(
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
# Used for Stable Audio and OmniGen
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
......
......@@ -13,17 +13,15 @@
# limitations under the License.
import math
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers
from ..attention_processor import Attention, AttentionProcessor
from ...utils import logging
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
......@@ -34,39 +32,21 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class OmniGenFeedForward(nn.Module):
r"""
A feed-forward layer for OmniGen.
Parameters:
hidden_size (`int`):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
"""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.activation_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
return self.down_proj(up_states)
class OmniGenPatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for OmniGen."""
def __init__(
self,
patch_size: int = 2,
......@@ -99,7 +79,7 @@ class OmniGenPatchEmbed(nn.Module):
)
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)
def cropped_pos_embed(self, height, width):
def _cropped_pos_embed(self, height, width):
"""Crops positional embeddings for SD3 compatibility."""
if self.pos_embed_max_size is None:
raise ValueError("`pos_embed_max_size` must be set for cropping.")
......@@ -122,43 +102,34 @@ class OmniGenPatchEmbed(nn.Module):
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
return spatial_pos_embed
def patch_embeddings(self, latent, is_input_image: bool):
def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
if is_input_image:
latent = self.input_image_proj(latent)
hidden_states = self.input_image_proj(hidden_states)
else:
latent = self.output_image_proj(latent)
latent = latent.flatten(2).transpose(1, 2)
return latent
def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None):
"""
Args:
latent: encoded image latents
is_input_image: use input_image_proj or output_image_proj
padding_latent:
When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence
length.
Returns: torch.Tensor
"""
if isinstance(latent, list):
hidden_states = self.output_image_proj(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
return hidden_states
def forward(
self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
) -> torch.Tensor:
if isinstance(hidden_states, list):
if padding_latent is None:
padding_latent = [None] * len(latent)
padding_latent = [None] * len(hidden_states)
patched_latents = []
for sub_latent, padding in zip(latent, padding_latent):
for sub_latent, padding in zip(hidden_states, padding_latent):
height, width = sub_latent.shape[-2:]
sub_latent = self.patch_embeddings(sub_latent, is_input_image)
pos_embed = self.cropped_pos_embed(height, width)
sub_latent = self._patch_embeddings(sub_latent, is_input_image)
pos_embed = self._cropped_pos_embed(height, width)
sub_latent = sub_latent + pos_embed
if padding is not None:
sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
patched_latents.append(sub_latent)
else:
height, width = latent.shape[-2:]
pos_embed = self.cropped_pos_embed(height, width)
latent = self.patch_embeddings(latent, is_input_image)
patched_latents = latent + pos_embed
height, width = hidden_states.shape[-2:]
pos_embed = self._cropped_pos_embed(height, width)
hidden_states = self._patch_embeddings(hidden_states, is_input_image)
patched_latents = hidden_states + pos_embed
return patched_latents
......@@ -180,15 +151,16 @@ class OmniGenSuScaledRotaryEmbedding(nn.Module):
self.long_factor = rope_scaling["long_factor"]
self.original_max_position_embeddings = original_max_position_embeddings
@torch.no_grad()
def forward(self, x, position_ids):
def forward(self, hidden_states, position_ids):
seq_len = torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device)
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
inv_freq_shape = (
torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim
)
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
......@@ -196,11 +168,11 @@ class OmniGenSuScaledRotaryEmbedding(nn.Module):
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = hidden_states.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
emb = torch.cat((freqs, freqs), dim=-1)[0]
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
......@@ -210,44 +182,7 @@ class OmniGenSuScaledRotaryEmbedding(nn.Module):
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
if len(cos.shape) == 2:
cos = cos[None, None]
sin = sin[None, None]
elif len(cos.shape) == 3:
cos = cos[:, None]
sin = sin[:, None]
cos, sin = cos.to(x.device), sin.to(x.device)
# Rotates half the hidden dims of the input. this rorate function is widely used in LLM, e.g. Llama, Phi3, etc.
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x_rotated = torch.cat((-x2, x1), dim=-1)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
return cos, sin
class OmniGenAttnProcessor2_0:
......@@ -278,7 +213,6 @@ class OmniGenAttnProcessor2_0:
bsz, q_len, query_dim = query.size()
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
......@@ -289,32 +223,19 @@ class OmniGenAttnProcessor2_0:
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
from ..embeddings import apply_rotary_emb
query, key = query.to(dtype), key.to(dtype)
query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).to(dtype)
hidden_states = hidden_states.transpose(1, 2).type_as(query)
hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim)
hidden_states = attn.to_out[0](hidden_states)
return hidden_states
class OmniGenBlock(nn.Module):
"""
A LuminaNextDiTBlock for LuminaNextDiT2DModel.
Parameters:
hidden_size (`int`): Embedding dimension of the input features.
num_attention_heads (`int`): Number of attention heads.
num_key_value_heads (`int`):
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
intermediate_size (`int`): size of intermediate layer.
rms_norm_eps (`float`): The eps for norm layer.
"""
def __init__(
self,
hidden_size: int,
......@@ -341,78 +262,77 @@ class OmniGenBlock(nn.Module):
self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
):
"""
Perform a forward pass through the LuminaNextDiTBlock.
Parameters:
hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock.
attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask.
image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_outputs = self.self_attn(
hidden_states=hidden_states,
encoder_hidden_states=hidden_states,
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
) -> torch.Tensor:
# 1. Attention
norm_hidden_states = self.input_layernorm(hidden_states)
attn_output = self.self_attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_output
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
# 2. Feed Forward
norm_hidden_states = self.post_attention_layernorm(hidden_states)
ff_output = self.mlp(norm_hidden_states)
hidden_states = hidden_states + ff_output
return hidden_states
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
"""
The Transformer model introduced in OmniGen.
Reference: https://arxiv.org/pdf/2409.11340
The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).
Parameters:
hidden_size (`int`, *optional*, defaults to 3072):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
rms_norm_eps (`float`, *optional*, defaults to 1e-5): eps for RMSNorm layer.
num_attention_heads (`int`, *optional*, defaults to 32):
The number of attention heads in each attention layer. This parameter specifies how many separate attention
mechanisms are used.
num_kv_heads (`int`, *optional*, defaults to 32):
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
If None, it defaults to num_attention_heads.
intermediate_size (`int`, *optional*, defaults to 8192): dimension of the intermediate layer in FFN
num_layers (`int`, *optional*, default to 32):
The number of layers in the model. This defines the depth of the neural network.
pad_token_id (`int`, *optional*, default to 32000):
id for pad token
vocab_size (`int`, *optional*, default to 32064):
size of vocabulary
patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb.
in_channels (`int`, defaults to `4`):
The number of channels in the input.
patch_size (`int`, defaults to `2`):
The size of the spatial patches to use in the patch embedding layer.
hidden_size (`int`, defaults to `3072`):
The dimensionality of the hidden layers in the model.
rms_norm_eps (`float`, defaults to `1e-5`):
Eps for RMSNorm layer.
num_attention_heads (`int`, defaults to `32`):
The number of heads to use for multi-head attention.
num_key_value_heads (`int`, defaults to `32`):
The number of heads to use for keys and values in multi-head attention.
intermediate_size (`int`, defaults to `8192`):
Dimension of the hidden layer in FeedForward layers.
num_layers (`int`, default to `32`):
The number of layers of transformer blocks to use.
pad_token_id (`int`, default to `32000`):
The id of the padding token.
vocab_size (`int`, default to `32064`):
The size of the vocabulary of the embedding vocabulary.
rope_base (`int`, default to `10000`):
The default theta value to use when creating RoPE.
rope_scaling (`Dict`, optional):
The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
pos_embed_max_size (`int`, default to `192`):
The maximum size of the positional embeddings.
time_step_dim (`int`, default to `256`):
Output dimension of timestep embeddings.
flip_sin_to_cos (`bool`, default to `True`):
Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
downscale_freq_shift (`int`, default to `0`):
The frequency shift to use when downscaling the timestep embeddings.
timestep_activation_fn (`str`, default to `silu`):
The activation function to use for the timestep embeddings.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["OmniGenBlock"]
_skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]
@register_to_config
def __init__(
self,
in_channels: int = 4,
patch_size: int = 2,
hidden_size: int = 3072,
rms_norm_eps: float = 1e-05,
rms_norm_eps: float = 1e-5,
num_attention_heads: int = 32,
num_key_value_heads: int = 32,
intermediate_size: int = 8192,
......@@ -423,8 +343,6 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
original_max_position_embeddings: int = 4096,
rope_base: int = 10000,
rope_scaling: Dict = None,
patch_size=2,
in_channels=4,
pos_embed_max_size: int = 192,
time_step_dim: int = 256,
flip_sin_to_cos: bool = True,
......@@ -434,8 +352,6 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.pos_embed_max_size = pos_embed_max_size
self.patch_embedding = OmniGenPatchEmbed(
patch_size=patch_size,
......@@ -448,11 +364,8 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
self.rotary_emb = OmniGenSuScaledRotaryEmbedding(
self.rope = OmniGenSuScaledRotaryEmbedding(
hidden_size // num_attention_heads,
max_position_embeddings=max_position_embeddings,
original_max_position_embeddings=original_max_position_embeddings,
......@@ -462,126 +375,34 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.layers = nn.ModuleList(
[
OmniGenBlock(
hidden_size,
num_attention_heads,
num_key_value_heads,
intermediate_size,
rms_norm_eps,
)
OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
for _ in range(num_layers)
]
)
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C) imgs: (N, H, W, C)
"""
c = self.out_channels
x = x.reshape(
shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c)
)
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def _get_multimodal_embeddings(
self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict
) -> Optional[torch.Tensor]:
if input_ids is None:
return None
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def get_multimodal_embeddings(
self,
input_ids: torch.Tensor,
input_img_latents: List[torch.Tensor],
input_image_sizes: Dict,
):
"""
get the multi-modal conditional embeddings
Args:
input_ids: a sequence of text id
input_img_latents: continues embedding of input images
input_image_sizes: the index of the input image in the input_ids sequence.
Returns: torch.Tensor
"""
input_img_latents = [x.to(self.dtype) for x in input_img_latents]
condition_tokens = None
if input_ids is not None:
condition_tokens = self.embed_tokens(input_ids)
input_img_inx = 0
if input_img_latents is not None:
input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
# replace the placeholder in text tokens with the image embedding.
condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
condition_tokens.dtype
)
input_img_inx += 1
condition_tokens = self.embed_tokens(input_ids)
input_img_inx = 0
input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
# replace the placeholder in text tokens with the image embedding.
condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
condition_tokens.dtype
)
input_img_inx += 1
return condition_tokens
def forward(
......@@ -593,106 +414,55 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
input_image_sizes: Dict[int, List[int]],
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
"""
The [`OmniGenTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
timestep (`torch.FloatTensor`):
Used to indicate denoising step.
input_ids (`torch.LongTensor`):
token ids
input_img_latents (`torch.Tensor`):
encoded image latents by VAE
input_image_sizes (`dict`):
the indices of the input_img_latents in the input_ids
attention_mask (`torch.Tensor`):
mask for self-attention
position_ids (`torch.LongTensor`):
id to represent position
past_key_values (`transformers.cache_utils.Cache`):
previous key and value states
offload_transformer_block (`bool`, *optional*, defaults to `True`):
offload transformer block to cpu
attention_kwargs: (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain tuple.
Returns:
If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a `tuple` where the first
element is the sample tensor.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
batch_size, num_channels, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height, post_patch_width = height // p, width // p
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.size()[-2:]
# 1. Patch & Timestep & Conditional Embedding
hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
num_tokens_for_output_image = hidden_states.size(1)
time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1)
timestep_proj = self.time_proj(timestep).type_as(hidden_states)
time_token = self.time_token(timestep_proj).unsqueeze(1)
temb = self.t_embedder(timestep_proj)
condition_tokens = self.get_multimodal_embeddings(
input_ids=input_ids,
input_img_latents=input_img_latents,
input_image_sizes=input_image_sizes,
)
condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
if condition_tokens is not None:
inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
else:
inputs_embeds = torch.cat([time_token, hidden_states], dim=1)
hidden_states = torch.cat([time_token, hidden_states], dim=1)
batch_size, seq_length = inputs_embeds.shape[:2]
seq_length = hidden_states.size(1)
position_ids = position_ids.view(-1, seq_length).long()
# 2. Attention mask preprocessing
if attention_mask is not None and attention_mask.dim() == 3:
dtype = inputs_embeds.dtype
dtype = hidden_states.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise Exception("attention_mask parameter was unavailable or invalid")
attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)
hidden_states = inputs_embeds
# 3. Rotary position embedding
image_rotary_emb = self.rope(hidden_states, position_ids)
image_rotary_emb = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers:
# 4. Transformer blocks
for block in self.layers:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
decoder_layer, hidden_states, attention_mask, image_rotary_emb
block, hidden_states, attention_mask, image_rotary_emb
)
else:
hidden_states = decoder_layer(
hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb
)
hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)
# 5. Output norm & projection
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, -num_tokens_for_output_image:]
timestep_proj = self.time_proj(timestep)
temb = self.t_embedder(timestep_proj.type_as(hidden_states))
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)
output = self.unpatchify(hidden_states, height, width)
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
......@@ -23,11 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .processor_omnigen import OmniGenMultiModalProcessor
......@@ -48,11 +44,12 @@ EXAMPLE_DOC_STRING = """
>>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
>>> image.save("t2i.png")
>>> image.save("output.png")
```
"""
......@@ -200,7 +197,6 @@ class OmniGenPipeline(
width,
use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if input_images is not None:
if len(input_images) != len(prompt):
......@@ -324,10 +320,8 @@ class OmniGenPipeline(
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 120000,
):
r"""
Function invoked when calling the pipeline for generation.
......@@ -376,10 +370,6 @@ class OmniGenPipeline(
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
......@@ -389,7 +379,6 @@ class OmniGenPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
......@@ -414,11 +403,9 @@ class OmniGenPipeline(
width,
use_input_image_size_as_output,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Define call parameters
......@@ -451,7 +438,8 @@ class OmniGenPipeline(
)
self._num_timesteps = len(timesteps)
# 6. Prepare latents.
# 6. Prepare latents
transformer_dtype = self.transformer.dtype
if use_input_image_size_as_output:
height, width = processed_data["input_pixel_values"][0].shape[-2:]
latent_channels = self.transformer.config.in_channels
......@@ -460,7 +448,7 @@ class OmniGenPipeline(
latent_channels,
height,
width,
self.transformer.dtype,
torch.float32,
device,
generator,
latents,
......@@ -471,6 +459,7 @@ class OmniGenPipeline(
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (num_cfg + 1))
latent_model_input = latent_model_input.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
......@@ -483,7 +472,6 @@ class OmniGenPipeline(
input_image_sizes=processed_data["input_image_sizes"],
attention_mask=processed_data["attention_mask"],
position_ids=processed_data["position_ids"],
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
......@@ -495,7 +483,6 @@ class OmniGenPipeline(
noise_pred = uncond + guidance_scale * (cond - uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
......@@ -506,11 +493,6 @@ class OmniGenPipeline(
latents = callback_outputs.pop("latents", latents)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
progress_bar.update()
if not output_type == "latent":
......
......@@ -18,17 +18,10 @@ from ..test_pipelines_common import PipelineTesterMixin
class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline
params = frozenset(
[
"prompt",
"guidance_scale",
]
)
batch_params = frozenset(
[
"prompt",
]
)
params = frozenset(["prompt", "guidance_scale"])
batch_params = frozenset(["prompt"])
test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment