Unverified Commit 8adc6003 authored by Edna's avatar Edna Committed by GitHub
Browse files

Chroma Pipeline (#11698)



* working state from hameerabbasi and iddl

* working state form hameerabbasi and iddl (transformer)

* working state (normalization)

* working state (embeddings)

* add chroma loader

* add chroma to mappings

* add chroma to transformer init

* take out variant stuff

* get decently far in changing variant stuff

* add chroma init

* make chroma output class

* add chroma transformer to dummy tp

* add chroma to init

* add chroma to init

* fix single file

* update

* update

* add chroma to auto pipeline

* add chroma to pipeline init

* change to chroma transformer

* take out variant from blocks

* swap embedder location

* remove prompt_2

* work on swapping text encoders

* remove mask function

* dont modify mask (for now)

* wrap attn mask

* no attn mask (can't get it to work)

* remove pooled prompt embeds

* change to my own unpooled embeddeer

* fix load

* take pooled projections out of transformer

* ensure correct dtype for chroma embeddings

* update

* use dn6 attn mask + fix true_cfg_scale

* use chroma pipeline output

* use DN6 embeddings

* remove guidance

* remove guidance embed (pipeline)

* remove guidance from embeddings

* don't return length

* dont change dtype

* remove unused stuff, fix up docs

* add chroma autodoc

* add .md (oops)

* initial chroma docs

* undo don't change dtype

* undo arxiv change

unsure why that happened

* fix hf papers regression in more places

* Update docs/source/en/api/pipelines/chroma.md
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* do_cfg -> self.do_classifier_free_guidance

* Update docs/source/en/api/models/chroma_transformer.md
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* Update chroma.md

* Move chroma layers into transformer

* Remove pruned AdaLayerNorms

* Add chroma fast tests

* (untested) batch cond and uncond

* Add # Copied from for shift

* Update # Copied from statements

* update norm imports

* Revert cond + uncond batching

* Add transformer tests

* move chroma test (oops)

* chroma init

* fix chroma pipeline fast tests

* Update src/diffusers/models/transformers/transformer_chroma.py
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* Move Approximator and Embeddings

* Fix auto pipeline + make style, quality

* make style

* Apply style fixes

* switch to new input ids

* fix # Copied from error

* remove # Copied from on protected members

* try to fix import

* fix import

* make fix-copes

* revert style fix

* update chroma transformer params

* update chroma transformer approximator init params

* update to pad tokens

* fix batch inference

* Make more pipeline tests work

* Make most transformer tests work

* fix docs

* make style, make quality

* skip batch tests

* fix test skipping

* fix test skipping again

* fix for tests

* Fix all pipeline test

* update

* push local changes, fix docs

* add encoder test, remove pooled dim

* default proj dim

* fix tests

* fix equal size list input

* update

* push local changes, fix docs

* add encoder test, remove pooled dim

* default proj dim

* fix tests

* fix equal size list input

* Revert "fix equal size list input"

This reverts commit 3fe4ad67d58d83715bc238f8654f5e90bfc5653c.

* update

* update

* update

* update

* update

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 9f91305f
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class ChromaPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
):
pipeline_class = ChromaPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = ChromaTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
axes_dims_rope=[4, 4, 8],
approximator_hidden_dim=32,
approximator_layers=1,
approximator_num_channels=16,
)
torch.manual_seed(0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"negative_prompt": "bad, ugly",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_chroma_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
"Fusion of QKV projections shouldn't affect the outputs."
)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Original outputs should match when fused QKV projections are disabled."
)
def test_chroma_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
......@@ -521,7 +521,8 @@ class FluxIPAdapterTesterMixin:
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
inputs["negative_prompt"] = ""
inputs["true_cfg_scale"] = 4.0
if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
inputs["true_cfg_scale"] = 4.0
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
......@@ -542,7 +543,11 @@ class FluxIPAdapterTesterMixin:
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_embed_dim = pipe.transformer.config.pooled_projection_dim
image_embed_dim = (
pipe.transformer.config.pooled_projection_dim
if hasattr(pipe.transformer.config, "pooled_projection_dim")
else 768
)
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment