Unverified commit dd07b19e authored by David Bertoin, committed by GitHub

Prx (#12525)

* rename photon to prx

* rename photon into prx

* Revert .gitignore to state before commit b7fb0fe9d63bf766bbe3c42ac154a043796dd370


* make fix-copies
parent 57636ad4
......@@ -541,12 +541,12 @@
title: PAG
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/photon
title: Photon
- local: api/pipelines/pixart
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/prx
title: PRX
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
......
......@@ -12,43 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Photon
# PRX
Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
PRX generates high-quality images from text using a simplified MMDiT architecture in which the text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
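As a quick orientation, the components named above can be inspected directly on a loaded pipeline. A minimal sketch (class names come from this PR; the printed values assume a Flux-VAE checkpoint):

```py
import torch

from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)

# The components match the description above: a PRX transformer, a T5Gemma
# text encoder, a flow-matching scheduler, and a Flux VAE (AutoencoderKL)
# or DC-AE (AutoencoderDC) depending on the checkpoint.
print(type(pipe.transformer).__name__)   # PRXTransformer2DModel
print(type(pipe.text_encoder).__name__)  # T5GemmaEncoder
print(type(pipe.scheduler).__name__)     # FlowMatchEulerDiscreteScheduler
print(type(pipe.vae).__name__)           # AutoencoderKL for this checkpoint
```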
## Available models
Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i) | 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i) | 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information.
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
from diffusers.pipelines.photon import PhotonPipeline
import torch

from diffusers.pipelines.prx import PRXPipeline
# Load the pipeline; the VAE and text encoder are downloaded from the Hugging Face Hub
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_output.png")
image.save("prx_output.png")
```
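For the distilled checkpoints, the table above suggests 8 steps with `guidance_scale=1.0`; everything else stays the same. A minimal sketch using the Flux-VAE distilled model:

```py
import torch

from diffusers.pipelines.prx import PRXPipeline

# Distilled variant: the table above suggests 8 steps and cfg=1.0
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
image = pipe(prompt, num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("prx_distilled_output.png")
```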
### Manual Component Loading
......@@ -57,9 +57,9 @@ Load components individually to customize the pipeline for instance to use quant
```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
......@@ -67,8 +67,8 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfig
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PhotonTransformer2DModel.from_pretrained(
"checkpoints/photon-512-t2i-sft",
transformer = PRXTransformer2DModel.from_pretrained(
"checkpoints/prx-512-t2i-sft",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
......@@ -76,7 +76,7 @@ transformer = PhotonTransformer2DModel.from_pretrained(
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"checkpoints/photon-512-t2i-sft", subfolder="scheduler"
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
......@@ -94,7 +94,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
pipe = PhotonPipeline(
pipe = PRXPipeline(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
......@@ -111,21 +111,21 @@ For memory-constrained environments:
```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.pipelines.prx import PRXPipeline
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
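The DC-AE checkpoints are another speed/memory lever: their 32x-compression autoencoder produces fewer latent tokens, so decoding is cheaper. Loading is identical; a sketch, assuming the DC-AE fine-tuned checkpoint from the table above:

```py
import torch

from diffusers.pipelines.prx import PRXPipeline

# DC-AE variant: 32x compression, 32 latent channels (see the model table above)
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-dc-ae-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    "A front-facing portrait of a lion in the golden savanna at sunset.",
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]
image.save("prx_dcae_output.png")
```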
## PhotonPipeline
## PRXPipeline
[[autodoc]] PhotonPipeline
[[autodoc]] PRXPipeline
- all
- __call__
## PhotonPipelineOutput
## PRXPipelineOutput
[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
#!/usr/bin/env python3
"""
Script to convert Photon checkpoint from original codebase to diffusers format.
Script to convert a PRX checkpoint from the original codebase to the diffusers format.
"""
import argparse
......@@ -13,15 +13,15 @@ from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline
DEFAULT_RESOLUTION = 512
@dataclass(frozen=True)
class PhotonBase:
class PRXBase:
context_in_dim: int = 2304
hidden_size: int = 1792
mlp_ratio: float = 3.5
......@@ -34,22 +34,22 @@ class PhotonBase:
@dataclass(frozen=True)
class PhotonFlux(PhotonBase):
class PRXFlux(PRXBase):
in_channels: int = 16
patch_size: int = 2
@dataclass(frozen=True)
class PhotonDCAE(PhotonBase):
class PRXDCAE(PRXBase):
in_channels: int = 32
patch_size: int = 1
def build_config(vae_type: str) -> Tuple[dict, int]:
if vae_type == "flux":
cfg = PhotonFlux()
cfg = PRXFlux()
elif vae_type == "dc-ae":
cfg = PhotonDCAE()
cfg = PRXDCAE()
else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
......@@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict:
# Key mappings for structural changes
mapping = {}
# Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention)
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
for i in range(depth):
# QKV projections moved to attention module
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
......@@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth
return converted_state_dict
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
"""Create and load PhotonTransformer2DModel from old checkpoint."""
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
"""Create and load PRXTransformer2DModel from old checkpoint."""
print(f"Loading checkpoint from: {checkpoint_path}")
......@@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
# Create transformer with config
print("Creating PhotonTransformer2DModel...")
transformer = PhotonTransformer2DModel(**config)
print("Creating PRXTransformer2DModel...")
transformer = PRXTransformer2DModel(**config)
# Load state dict
print("Loading converted parameters...")
......@@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str)
vae_class = "AutoencoderDC"
model_index = {
"_class_name": "PhotonPipeline",
"_class_name": "PRXPipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"default_sample_size": default_image_size,
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["photon", "T5GemmaEncoder"],
"text_encoder": ["prx", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "PhotonTransformer2DModel"],
"transformer": ["diffusers", "PRXTransformer2DModel"],
"vae": ["diffusers", vae_class],
}
......@@ -275,7 +275,7 @@ def main(args):
# Verify the pipeline can be loaded
try:
pipeline = PhotonPipeline.from_pretrained(args.output_path)
pipeline = PRXPipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
......@@ -298,10 +298,10 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
)
parser.add_argument(
......
......@@ -232,9 +232,9 @@ else:
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PhotonTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"PRXTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
......@@ -516,11 +516,11 @@ else:
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PhotonPipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"PRXPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
......@@ -928,9 +928,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiControlNetModel,
OmniGenTransformer2DModel,
ParallelConfig,
PhotonTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
......@@ -1182,11 +1182,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PhotonPipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PRXPipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
......
......@@ -96,7 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
......@@ -191,9 +191,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PhotonTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SD3Transformer2DModel,
......
......@@ -32,7 +32,7 @@ if is_torch_available():
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_photon import PhotonTransformer2DModel
from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
......
......@@ -80,9 +80,9 @@ def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
return xq_out.reshape(*xq.shape).type_as(xq)
class PhotonAttnProcessor2_0:
class PRXAttnProcessor2_0:
r"""
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
"""
......@@ -91,11 +91,11 @@ class PhotonAttnProcessor2_0:
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: "PhotonAttention",
attn: "PRXAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
......@@ -103,10 +103,10 @@ class PhotonAttnProcessor2_0:
**kwargs,
) -> torch.Tensor:
"""
Apply Photon attention using PhotonAttention module.
Apply PRX attention using PRXAttention module.
Args:
attn: PhotonAttention module containing projection layers
attn: PRXAttention module containing projection layers
hidden_states: Image tokens [B, L_img, D]
encoder_hidden_states: Text tokens [B, L_txt, D]
attention_mask: Boolean mask for text tokens [B, L_txt]
......@@ -114,7 +114,7 @@ class PhotonAttnProcessor2_0:
"""
if encoder_hidden_states is None:
raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
# Project image tokens to Q, K, V
img_qkv = attn.img_qkv_proj(hidden_states)
......@@ -190,14 +190,14 @@ class PhotonAttnProcessor2_0:
return attn_output
class PhotonAttention(nn.Module, AttentionModuleMixin):
class PRXAttention(nn.Module, AttentionModuleMixin):
r"""
Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
Photon's architecture.
PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
PRX's architecture.
"""
_default_processor_cls = PhotonAttnProcessor2_0
_available_processors = [PhotonAttnProcessor2_0]
_default_processor_cls = PRXAttnProcessor2_0
_available_processors = [PRXAttnProcessor2_0]
def __init__(
self,
......@@ -251,7 +251,7 @@ class PhotonAttention(nn.Module, AttentionModuleMixin):
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PhotonEmbedND(nn.Module):
class PRXEmbedND(nn.Module):
r"""
N-dimensional rotary positional embedding.
......@@ -347,7 +347,7 @@ class Modulation(nn.Module):
return tuple(out[:3]), tuple(out[3:])
class PhotonBlock(nn.Module):
class PRXBlock(nn.Module):
r"""
Multimodal transformer block with text–image cross-attention, modulation, and MLP.
......@@ -364,7 +364,7 @@ class PhotonBlock(nn.Module):
Attributes:
img_pre_norm (`nn.LayerNorm`):
Pre-normalization applied to image tokens before attention.
attention (`PhotonAttention`):
attention (`PRXAttention`):
Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
image and text tokens.
post_attention_layernorm (`nn.LayerNorm`):
......@@ -400,15 +400,15 @@ class PhotonBlock(nn.Module):
# Pre-attention normalization for image tokens
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# PhotonAttention module with built-in projections and norms
self.attention = PhotonAttention(
# PRXAttention module with built-in projections and norms
self.attention = PRXAttention(
query_dim=hidden_size,
heads=num_heads,
dim_head=self.head_dim,
bias=False,
out_bias=False,
eps=1e-6,
processor=PhotonAttnProcessor2_0(),
processor=PRXAttnProcessor2_0(),
)
# mlp
......@@ -557,7 +557,7 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
Transformer-based 2D model for text to image generation.
......@@ -595,7 +595,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
txt_in (`nn.Linear`):
Projection layer for text conditioning.
blocks (`nn.ModuleList`):
Stack of transformer blocks (`PhotonBlock`).
Stack of transformer blocks (`PRXBlock`).
final_layer (`LastLayer`):
Projection layer mapping hidden tokens back to patch outputs.
......@@ -661,14 +661,14 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.blocks = nn.ModuleList(
[
PhotonBlock(
PRXBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
......@@ -702,7 +702,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
r"""
Forward pass of the PhotonTransformer2DModel.
Forward pass of the PRXTransformer2DModel.
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
......
......@@ -144,7 +144,7 @@ else:
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["photon"] = ["PhotonPipeline"]
_import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
......@@ -718,9 +718,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .photon import PhotonPipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
......
......@@ -12,7 +12,7 @@ from ...utils import (
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]}
_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
......@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_photon"] = ["PhotonPipeline"]
_import_structure["pipeline_prx"] = ["PRXPipeline"]
# Import T5GemmaEncoder for pipeline loading compatibility
try:
......@@ -44,8 +44,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_output import PhotonPipelineOutput
from .pipeline_photon import PhotonPipeline
from .pipeline_output import PRXPipelineOutput
from .pipeline_prx import PRXPipeline
else:
import sys
......
......@@ -22,9 +22,9 @@ from ...utils import BaseOutput
@dataclass
class PhotonPipelineOutput(BaseOutput):
class PRXPipelineOutput(BaseOutput):
"""
Output class for Photon pipelines.
Output class for PRX pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
......
......@@ -30,9 +30,9 @@ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.image_processor import PixArtImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
......@@ -73,7 +73,7 @@ logger = logging.get_logger(__name__)
class TextPreprocessor:
"""Text preprocessing utility for PhotonPipeline."""
"""Text preprocessing utility for PRXPipeline."""
def __init__(self):
"""Initialize text preprocessor."""
......@@ -203,34 +203,34 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PhotonPipeline
>>> from diffusers import PRXPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
>>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
>>> pipe.to("cuda")
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("photon_output.png")
>>> image.save("prx_output.png")
```
"""
class PhotonPipeline(
class PRXPipeline(
DiffusionPipeline,
LoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
Pipeline for text-to-image generation using Photon Transformer.
Pipeline for text-to-image generation using PRX Transformer.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
transformer ([`PhotonTransformer2DModel`]):
The Photon transformer model to denoise the encoded image latents.
transformer ([`PRXTransformer2DModel`]):
The PRX transformer model to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
text_encoder ([`T5GemmaEncoder`]):
......@@ -248,7 +248,7 @@ class PhotonPipeline(
def __init__(
self,
transformer: PhotonTransformer2DModel,
transformer: PRXTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
text_encoder: T5GemmaEncoder,
tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
......@@ -257,9 +257,9 @@ class PhotonPipeline(
):
super().__init__()
if PhotonTransformer2DModel is None:
if PRXTransformer2DModel is None:
raise ImportError(
"PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
"PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
)
self.text_preprocessor = TextPreprocessor()
......@@ -567,7 +567,7 @@ class PhotonPipeline(
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
......@@ -585,9 +585,8 @@ class PhotonPipeline(
Examples:
Returns:
[`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
[`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Set height and width
......@@ -765,4 +764,4 @@ class PhotonPipeline(
if not return_dict:
return (image,)
return PhotonPipelineOutput(images=image)
return PRXPipelineOutput(images=image)
......@@ -1098,7 +1098,7 @@ class ParallelConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PhotonTransformer2DModel(metaclass=DummyObject):
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
......@@ -1113,7 +1113,7 @@ class PhotonTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PixArtTransformer2DModel(metaclass=DummyObject):
class PriorTransformer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
......@@ -1128,7 +1128,7 @@ class PixArtTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PriorTransformer(metaclass=DummyObject):
class PRXTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
......
......@@ -1847,7 +1847,7 @@ class PaintByExamplePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PhotonPipeline(metaclass=DummyObject):
class PIAPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......@@ -1862,7 +1862,7 @@ class PhotonPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PIAPipeline(metaclass=DummyObject):
class PixArtAlphaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......@@ -1877,7 +1877,7 @@ class PIAPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PixArtAlphaPipeline(metaclass=DummyObject):
class PixArtSigmaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......@@ -1892,7 +1892,7 @@ class PixArtAlphaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PixArtSigmaPAGPipeline(metaclass=DummyObject):
class PixArtSigmaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......@@ -1907,7 +1907,7 @@ class PixArtSigmaPAGPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PixArtSigmaPipeline(metaclass=DummyObject):
class PRXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
......
......@@ -17,7 +17,7 @@ import unittest
import torch
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
......@@ -26,8 +26,8 @@ from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PhotonTransformer2DModel
class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PRXTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -75,7 +75,7 @@ class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"PhotonTransformer2DModel"}
expected_set = {"PRXTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
......
......@@ -8,8 +8,8 @@ from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5G
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
......@@ -22,8 +22,8 @@ from ..test_pipelines_common import PipelineTesterMixin
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PhotonPipeline
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PRXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
test_xformers_attention = False
......@@ -32,16 +32,16 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@classmethod
def setUpClass(cls):
# Ensure PhotonPipeline has an _execution_device property expected by __call__
if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property):
# Ensure PRXPipeline has an _execution_device property expected by __call__
if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
try:
setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
except Exception:
pass
def get_dummy_components(self):
torch.manual_seed(0)
transformer = PhotonTransformer2DModel(
transformer = PRXTransformer2DModel(
patch_size=1,
in_channels=4,
context_in_dim=8,
......@@ -129,7 +129,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = PhotonPipeline(**components)
pipe = PRXPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
try:
......@@ -148,7 +148,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_callback_inputs(self):
components = self.get_dummy_components()
pipe = PhotonPipeline(**components)
pipe = PRXPipeline(**components)
pipe = pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
try:
......@@ -157,7 +157,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pass
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
......@@ -216,7 +216,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
def test_inference_with_autoencoder_dc(self):
"""Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
"""Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
device = "cpu"
components = self.get_dummy_components()
......@@ -248,7 +248,7 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components["vae"] = vae_dc
pipe = PhotonPipeline(**components)
pipe = PRXPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
......