"vscode:/vscode.git/clone" did not exist on "e38b9ee4fc64fcb05447b6839615905adc5673f7"
Unverified commit dd07b19e, authored by David Bertoin, committed by GitHub

Prx (#12525)

* rename photon to prx

* rename photon into prx

* Revert .gitignore to state before commit b7fb0fe9d63bf766bbe3c42ac154a043796dd370

* make fix-copies
parent 57636ad4
@@ -541,12 +541,12 @@
      title: PAG
    - local: api/pipelines/paint_by_example
      title: Paint by Example
    - local: api/pipelines/pixart
      title: PixArt-α
    - local: api/pipelines/pixart_sigma
      title: PixArt-Σ
    - local: api/pipelines/prx
      title: PRX
    - local: api/pipelines/qwenimage
      title: QwenImage
    - local: api/pipelines/sana
...
@@ -12,43 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->

# PRX
PRX generates high-quality images from text using a simplified MMDIT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
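The VAE choice also sets the transformer's sequence length, since the latent grid is patchified into tokens. A back-of-the-envelope sketch (patch sizes of 2 for the Flux VAE and 1 for DC-AE, as in the conversion script further below):

```py
# Rough token-count comparison for a 512x512 image.
def num_image_tokens(image_size: int, compression: int, patch_size: int) -> int:
    latent = image_size // compression   # side length of the latent grid
    return (latent // patch_size) ** 2   # tokens after patchification

print(num_image_tokens(512, compression=8, patch_size=2))   # Flux VAE: 64x64 latents -> 1024 tokens
print(num_image_tokens(512, compression=32, patch_size=1))  # DC-AE: 16x16 latents -> 256 tokens
```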
## Available models

PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:----------:|:----------:|:---------:|:-----------:|:-----------------:|:--------------------:|:-----------------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i) | 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i) | 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
Refer to [this collection](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) for more information.
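Note that the distilled checkpoints are sampled with far fewer steps and guidance effectively disabled. A minimal sketch following the suggested parameters from the table above:

```py
import torch
from diffusers import PRXPipeline

# Distilled variant: 8 steps with cfg=1.0 (guidance effectively off), per the table above.
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("A lighthouse on a cliff at dawn", num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("prx_distilled.png")
```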
## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
import torch
from diffusers.pipelines.prx import PRXPipeline

# Load pipeline - VAE and text encoder will be loaded from Hugging Face
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
### Manual Component Loading

@@ -57,9 +57,9 @@ Load components individually to customize the pipeline for instance to use quant
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

@@ -67,8 +67,8 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfig
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)

# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
    "checkpoints/prx-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,

@@ -76,7 +76,7 @@ transformer = PRXTransformer2DModel.from_pretrained(
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder

@@ -94,7 +94,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16)

pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
@@ -111,21 +111,21 @@ For memory-constrained environments:
```py
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PRXPipeline

[[autodoc]] PRXPipeline
    - all
    - __call__

## PRXPipelineOutput

[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
#!/usr/bin/env python3
"""
Script to convert PRX checkpoint from original codebase to diffusers format.
"""

import argparse

@@ -13,15 +13,15 @@ from typing import Dict, Tuple
import torch
from safetensors.torch import save_file

from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline

DEFAULT_RESOLUTION = 512


@dataclass(frozen=True)
class PRXBase:
    context_in_dim: int = 2304
    hidden_size: int = 1792
    mlp_ratio: float = 3.5

@@ -34,22 +34,22 @@ class PRXBase:
@dataclass(frozen=True)
class PRXFlux(PRXBase):
    in_channels: int = 16
    patch_size: int = 2


@dataclass(frozen=True)
class PRXDCAE(PRXBase):
    in_channels: int = 32
    patch_size: int = 1


def build_config(vae_type: str) -> Tuple[dict, int]:
    if vae_type == "flux":
        cfg = PRXFlux()
    elif vae_type == "dc-ae":
        cfg = PRXDCAE()
    else:
        raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")

@@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict:
    # Key mappings for structural changes
    mapping = {}

    # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
    for i in range(depth):
        # QKV projections moved to attention module
        mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"

@@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth
    return converted_state_dict
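The conversion is essentially a key-for-key rename over the checkpoint's state dict; tensors are untouched. A minimal sketch of that pattern (the helper name is illustrative, not the script's actual function):

```py
import torch
from typing import Dict

def remap_state_dict(old_sd: Dict[str, torch.Tensor], mapping: Dict[str, str]) -> Dict[str, torch.Tensor]:
    # Rename keys according to `mapping`; keys without an entry keep their name.
    return {mapping.get(k, k): v for k, v in old_sd.items()}

# Example: block 0's QKV projection moved into the attention submodule.
mapping = {"blocks.0.img_qkv_proj.weight": "blocks.0.attention.img_qkv_proj.weight"}
old_sd = {"blocks.0.img_qkv_proj.weight": torch.zeros(3, 3), "time_in.weight": torch.zeros(3)}
new_sd = remap_state_dict(old_sd, mapping)
assert "blocks.0.attention.img_qkv_proj.weight" in new_sd
```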
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
    """Create and load PRXTransformer2DModel from old checkpoint."""

    print(f"Loading checkpoint from: {checkpoint_path}")

@@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph
    converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)

    # Create transformer with config
    print("Creating PRXTransformer2DModel...")
    transformer = PRXTransformer2DModel(**config)

    # Load state dict
    print("Loading converted parameters...")

@@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str)
        vae_class = "AutoencoderDC"

    model_index = {
        "_class_name": "PRXPipeline",
        "_diffusers_version": "0.31.0.dev0",
        "_name_or_path": os.path.basename(output_path),
        "default_sample_size": default_image_size,
        "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
        "text_encoder": ["prx", "T5GemmaEncoder"],
        "tokenizer": ["transformers", "GemmaTokenizerFast"],
        "transformer": ["diffusers", "PRXTransformer2DModel"],
        "vae": ["diffusers", vae_class],
    }
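For context, `model_index.json` is the manifest that `from_pretrained` reads to decide which class to load for each component subfolder. A sketch of how a dict like the one above would be written out (the path is hypothetical):

```py
import json
import os

output_path = "converted/prx-512-t2i"  # hypothetical output directory
os.makedirs(output_path, exist_ok=True)
with open(os.path.join(output_path, "model_index.json"), "w") as f:
    json.dump(model_index, f, indent=2)  # `model_index` as built above
```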
@@ -275,7 +275,7 @@ def main(args):
    # Verify the pipeline can be loaded
    try:
        pipeline = PRXPipeline.from_pretrained(args.output_path)
        print("Pipeline loaded successfully!")
        print(f"Transformer: {type(pipeline.transformer).__name__}")
        print(f"VAE: {type(pipeline.vae).__name__}")

@@ -298,10 +298,10 @@ def main(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
    parser.add_argument(
        "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file)"
    )
    parser.add_argument(
...

@@ -232,9 +232,9 @@ else:
"MultiControlNetModel", "MultiControlNetModel",
"OmniGenTransformer2DModel", "OmniGenTransformer2DModel",
"ParallelConfig", "ParallelConfig",
"PhotonTransformer2DModel",
"PixArtTransformer2DModel", "PixArtTransformer2DModel",
"PriorTransformer", "PriorTransformer",
"PRXTransformer2DModel",
"QwenImageControlNetModel", "QwenImageControlNetModel",
"QwenImageMultiControlNetModel", "QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel", "QwenImageTransformer2DModel",
@@ -516,11 +516,11 @@ else:
    "MusicLDMPipeline",
    "OmniGenPipeline",
    "PaintByExamplePipeline",
    "PIAPipeline",
    "PixArtAlphaPipeline",
    "PixArtSigmaPAGPipeline",
    "PixArtSigmaPipeline",
    "PRXPipeline",
    "QwenImageControlNetInpaintPipeline",
    "QwenImageControlNetPipeline",
    "QwenImageEditInpaintPipeline",
@@ -928,9 +928,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    MultiControlNetModel,
    OmniGenTransformer2DModel,
    ParallelConfig,
    PixArtTransformer2DModel,
    PriorTransformer,
    PRXTransformer2DModel,
    QwenImageControlNetModel,
    QwenImageMultiControlNetModel,
    QwenImageTransformer2DModel,
@@ -1182,11 +1182,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    MusicLDMPipeline,
    OmniGenPipeline,
    PaintByExamplePipeline,
    PIAPipeline,
    PixArtAlphaPipeline,
    PixArtSigmaPAGPipeline,
    PixArtSigmaPipeline,
    PRXPipeline,
    QwenImageControlNetInpaintPipeline,
    QwenImageControlNetPipeline,
    QwenImageEditInpaintPipeline,
...
@@ -96,7 +96,7 @@ if is_torch_available():
    _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
    _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
    _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
    _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
    _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -191,9 +191,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LuminaNextDiT2DModel,
        MochiTransformer3DModel,
        OmniGenTransformer2DModel,
        PixArtTransformer2DModel,
        PriorTransformer,
        PRXTransformer2DModel,
        QwenImageTransformer2DModel,
        SanaTransformer2DModel,
        SD3Transformer2DModel,
...
@@ -32,7 +32,7 @@ if is_torch_available():
    from .transformer_lumina2 import Lumina2Transformer2DModel
    from .transformer_mochi import MochiTransformer3DModel
    from .transformer_omnigen import OmniGenTransformer2DModel
    from .transformer_prx import PRXTransformer2DModel
    from .transformer_qwenimage import QwenImageTransformer2DModel
    from .transformer_sd3 import SD3Transformer2DModel
    from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
...
@@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    return xq_out.reshape(*xq.shape).type_as(xq)


class PRXAttnProcessor2_0:
    r"""
    Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
    backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
    """
@@ -91,11 +91,11 @@ class PRXAttnProcessor2_0:
    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "PRXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
@@ -103,10 +103,10 @@ class PRXAttnProcessor2_0:
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply PRX attention using the PRXAttention module.

        Args:
            attn: PRXAttention module containing projection layers
            hidden_states: Image tokens [B, L_img, D]
            encoder_hidden_states: Text tokens [B, L_txt, D]
            attention_mask: Boolean mask for text tokens [B, L_txt]
@@ -114,7 +114,7 @@ class PRXAttnProcessor2_0:
        """
        if encoder_hidden_states is None:
            raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

        # Project image tokens to Q, K, V
        img_qkv = attn.img_qkv_proj(hidden_states)
@@ -190,14 +190,14 @@ class PRXAttnProcessor2_0:
        return attn_output
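Before the attention call, queries and keys are rotated with `apply_rope`. A minimal sketch of the usual complex-pair rotation behind such a helper, matching the return line shown above (the `freqs_cis` layout is an assumption):

```py
import torch

def apply_rope_sketch(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # View the head dimension as pairs and rotate each pair by a position-dependent
    # angle; freqs_cis is assumed to hold 2x2 rotation coefficients per pair.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq)
```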
class PRXAttention(nn.Module, AttentionModuleMixin):
    r"""
    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
    PRX's architecture.
    """

    _default_processor_cls = PRXAttnProcessor2_0
    _available_processors = [PRXAttnProcessor2_0]

    def __init__(
        self,

@@ -251,7 +251,7 @@ class PRXAttention(nn.Module, AttentionModuleMixin):
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PRXEmbedND(nn.Module):
    r"""
    N-dimensional rotary positional embedding.

@@ -347,7 +347,7 @@ class Modulation(nn.Module):
        return tuple(out[:3]), tuple(out[3:])
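`Modulation` returns two `(shift, scale, gate)` triples, one for the attention branch and one for the MLP branch. In adaLN-style blocks these are typically applied as follows (a schematic, not this file's exact code):

```py
import torch

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Timestep-dependent scale-and-shift of the normalized activations.
    return x * (1 + scale) + shift

# The gate then weights the branch output before the residual addition:
#   x = x + gate * branch(modulate(norm(x), shift, scale))
```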
class PRXBlock(nn.Module):
    r"""
    Multimodal transformer block with text–image cross-attention, modulation, and MLP.

@@ -364,7 +364,7 @@ class PRXBlock(nn.Module):
    Attributes:
        img_pre_norm (`nn.LayerNorm`):
            Pre-normalization applied to image tokens before attention.
        attention (`PRXAttention`):
            Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
            image and text tokens.
        post_attention_layernorm (`nn.LayerNorm`):

@@ -400,15 +400,15 @@ class PRXBlock(nn.Module):
        # Pre-attention normalization for image tokens
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        # PRXAttention module with built-in projections and norms
        self.attention = PRXAttention(
            query_dim=hidden_size,
            heads=num_heads,
            dim_head=self.head_dim,
            bias=False,
            out_bias=False,
            eps=1e-6,
            processor=PRXAttnProcessor2_0(),
        )

        # mlp

@@ -557,7 +557,7 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
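`seq2img` is the inverse of patchification: token vectors are folded back onto the latent grid. A small round-trip sketch of the idea (helper names are hypothetical):

```py
import torch
import torch.nn.functional as F

def img2seq_sketch(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # [B, C, H, W] -> [B, L, C*p*p], one token per non-overlapping patch
    return F.unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)

def seq2img_sketch(seq: torch.Tensor, patch_size: int, shape) -> torch.Tensor:
    # [B, L, C*p*p] -> [B, C, H, W], exact inverse for non-overlapping patches
    return F.fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)

x = torch.randn(1, 16, 64, 64)            # e.g. Flux-VAE latents
tokens = img2seq_sketch(x, patch_size=2)  # [1, 1024, 64]
assert torch.equal(seq2img_sketch(tokens, 2, (64, 64)), x)
```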
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
    r"""
    Transformer-based 2D model for text-to-image generation.

@@ -595,7 +595,7 @@ class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
        txt_in (`nn.Linear`):
            Projection layer for text conditioning.
        blocks (`nn.ModuleList`):
            Stack of transformer blocks (`PRXBlock`).
        final_layer (`LastLayer`):
            Projection layer mapping hidden tokens back to patch outputs.

@@ -661,14 +661,14 @@ class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.blocks = nn.ModuleList(
            [
                PRXBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=mlp_ratio,

@@ -702,7 +702,7 @@ class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
        r"""
        Forward pass of the PRXTransformer2DModel.

        The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
        transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
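Schematically, the forward pass described above reduces to the following flow (pseudocode with names assumed from the surrounding file, not the actual implementation):

```py
# img_tokens = img_in(img2seq(latents, patch_size))        # patchify + project
# txt_tokens = txt_in(encoder_hidden_states)               # project text conditioning
# t_emb = time_in(timestep_embedding(timestep, 256))       # timestep embedding
# pe = pe_embedder(position_ids)                           # rotary position embeddings
# for block in blocks:
#     img_tokens = block(img_tokens, txt_tokens, t_emb, pe)
# out = seq2img(final_layer(img_tokens, t_emb), patch_size, latent_shape)
```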
...

@@ -144,7 +144,7 @@ else:
"FluxKontextPipeline", "FluxKontextPipeline",
"FluxKontextInpaintPipeline", "FluxKontextInpaintPipeline",
] ]
_import_structure["photon"] = ["PhotonPipeline"] _import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [ _import_structure["audioldm2"] = [
"AudioLDM2Pipeline", "AudioLDM2Pipeline",
...@@ -718,9 +718,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -718,9 +718,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        StableDiffusionXLPAGPipeline,
    )
    from .paint_by_example import PaintByExamplePipeline
    from .pia import PIAPipeline
    from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
    from .prx import PRXPipeline
    from .qwenimage import (
        QwenImageControlNetInpaintPipeline,
        QwenImageControlNetPipeline,
...

@@ -12,7 +12,7 @@ from ...utils import (
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):

@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_prx"] = ["PRXPipeline"]

    # Import T5GemmaEncoder for pipeline loading compatibility
    try:

@@ -44,8 +44,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_output import PRXPipelineOutput
        from .pipeline_prx import PRXPipeline

else:
    import sys
...

@@ -22,9 +22,9 @@ from ...utils import BaseOutput
@dataclass
class PRXPipelineOutput(BaseOutput):
    """
    Output class for PRX pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
...

@@ -30,9 +30,9 @@ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.image_processor import PixArtImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    logging,

@@ -73,7 +73,7 @@ logger = logging.get_logger(__name__)
class TextPreprocessor:
    """Text preprocessing utility for PRXPipeline."""

    def __init__(self):
        """Initialize text preprocessor."""

@@ -203,34 +203,34 @@ EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import PRXPipeline

        >>> # Load pipeline with from_pretrained
        >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
        >>> pipe.to("cuda")

        >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
        >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
        >>> image.save("prx_output.png")
        ```
"""
class PRXPipeline(
    DiffusionPipeline,
    LoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    Pipeline for text-to-image generation using the PRX Transformer.
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        transformer ([`PRXTransformer2DModel`]):
            The PRX transformer model to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        text_encoder ([`T5GemmaEncoder`]):

@@ -248,7 +248,7 @@ class PRXPipeline(
    def __init__(
        self,
        transformer: PRXTransformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        text_encoder: T5GemmaEncoder,
        tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],

@@ -257,9 +257,9 @@ class PRXPipeline(
    ):
        super().__init__()

        if PRXTransformer2DModel is None:
            raise ImportError(
                "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
            )

        self.text_preprocessor = TextPreprocessor()

@@ -567,7 +567,7 @@ class PRXPipeline(
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
            use_resolution_binning (`bool`, *optional*, defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions using
                predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back

@@ -585,9 +585,8 @@ class PRXPipeline(
        Examples:

        Returns:
            [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # 0. Set height and width

@@ -765,4 +764,4 @@ class PRXPipeline(
        if not return_dict:
            return (image,)

        return PRXPipelineOutput(images=image)
@@ -1098,7 +1098,7 @@ class ParallelConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class PixArtTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -1113,7 +1113,7 @@
        requires_backends(cls, ["torch"])


class PriorTransformer(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -1128,7 +1128,7 @@
        requires_backends(cls, ["torch"])


class PRXTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
...
@@ -1847,7 +1847,7 @@ class PaintByExamplePipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class PIAPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -1862,7 +1862,7 @@
        requires_backends(cls, ["torch", "transformers"])


class PixArtAlphaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -1877,7 +1877,7 @@
        requires_backends(cls, ["torch", "transformers"])


class PixArtSigmaPAGPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -1892,7 +1892,7 @@
        requires_backends(cls, ["torch", "transformers"])


class PixArtSigmaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -1907,7 +1907,7 @@
        requires_backends(cls, ["torch", "transformers"])


class PRXPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
...
@@ -17,7 +17,7 @@ import unittest
import torch

from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -26,8 +26,8 @@ from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()


class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = PRXTransformer2DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True
@@ -75,7 +75,7 @@ class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"PRXTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
...
@@ -8,8 +8,8 @@ from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5G
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder

from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
@@ -22,8 +22,8 @@ from ..test_pipelines_common import PipelineTesterMixin
    reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
    strict=False,
)
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = PRXPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
    test_xformers_attention = False
@@ -32,16 +32,16 @@ class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Ensure PRXPipeline has an _execution_device property expected by __call__
        if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
            try:
                setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
            except Exception:
                pass

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = PRXTransformer2DModel(
            patch_size=1,
            in_channels=4,
            context_in_dim=8,
@@ -129,7 +129,7 @@ class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = PRXPipeline(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)
        try:
@@ -148,7 +148,7 @@ class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def test_callback_inputs(self):
        components = self.get_dummy_components()
        pipe = PRXPipeline(**components)
        pipe = pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)
        try:
@@ -157,7 +157,7 @@ class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            pass

        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )
        def callback_inputs_subset(pipe, i, t, callback_kwargs):

@@ -216,7 +216,7 @@ class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
    def test_inference_with_autoencoder_dc(self):
        """Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
        device = "cpu"

        components = self.get_dummy_components()

@@ -248,7 +248,7 @@ class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        components["vae"] = vae_dc

        pipe = PRXPipeline(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)
...