Unverified commit a1f36ee3 authored by Sayak Paul, committed by GitHub

[Z-Image] various small changes, Z-Image transformer tests, etc. (#12741)



* start zimage model tests.

* up

* up

* up

* up

* up

* up

* up

* up

* up

* up

* up

* up

* Revert "up"

This reverts commit bca3e27c96b942db49ccab8ddf824e7a54d43ed1.

* expand upon compilation failure reason.

* Update tests/models/transformers/test_models_transformer_z_image.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* reinitialize the padding tokens to ones to prevent NaN problems.

* updates

* up

* skipping ZImage DiT tests

* up

* up

---------
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
parent d96cbaca
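For context on the NaN fix mentioned in the commit message: `torch.empty` allocates uninitialized memory, so a parameter created from it can hold arbitrary (even non-finite) values until real weights are loaded. A minimal, self-contained illustration of the workaround the tests below apply to the padding-token parameters (the tensor here is only a stand-in, not the actual model attribute):

```python
import torch

# A stand-in for the model's `x_pad_token` / `cap_pad_token` parameters, which are
# created with torch.empty and therefore contain whatever bytes happened to be in memory.
pad_token = torch.nn.Parameter(torch.empty((1, 16)))

# Overwriting the uninitialized values with ones (as the new tests do via copy_)
# makes randomly initialized test models deterministic and keeps forward passes NaN-free.
with torch.no_grad():
    pad_token.copy_(torch.ones_like(pad_token))

assert torch.isfinite(pad_token).all()
```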
@@ -27,6 +27,7 @@ from ...models.modeling_utils import ModelMixin
 from ...models.normalization import RMSNorm
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_dispatch import dispatch_attention_fn
+from ..modeling_outputs import Transformer2DModelOutput

 ADALN_EMBED_DIM = 256
@@ -39,17 +40,9 @@ class TimestepEmbedder(nn.Module):
         if mid_size is None:
             mid_size = out_size
         self.mlp = nn.Sequential(
-            nn.Linear(
-                frequency_embedding_size,
-                mid_size,
-                bias=True,
-            ),
+            nn.Linear(frequency_embedding_size, mid_size, bias=True),
             nn.SiLU(),
-            nn.Linear(
-                mid_size,
-                out_size,
-                bias=True,
-            ),
+            nn.Linear(mid_size, out_size, bias=True),
         )
         self.frequency_embedding_size = frequency_embedding_size
@@ -211,9 +204,7 @@ class ZImageTransformerBlock(nn.Module):
         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
-            )
+            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

     def forward(
         self,
@@ -230,33 +221,19 @@ class ZImageTransformerBlock(nn.Module):
             # Attention block
-            attn_out = self.attention(
-                self.attention_norm1(x) * scale_msa,
-                attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
-            )
+            attn_out = self.attention(
+                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
+            )
             x = x + gate_msa * self.attention_norm2(attn_out)

             # FFN block
-            x = x + gate_mlp * self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x) * scale_mlp,
-                )
-            )
+            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
         else:
             # Attention block
-            attn_out = self.attention(
-                self.attention_norm1(x),
-                attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
-            )
+            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
             x = x + self.attention_norm2(attn_out)

             # FFN block
-            x = x + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x),
-                )
-            )
+            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

         return x
@@ -404,10 +381,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
             ]
         )
         self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
-        self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps),
-            nn.Linear(cap_feat_dim, dim, bias=True),
-        )
+        self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

         self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
         self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -494,11 +468,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
             )
             # padded feature
-            cap_padded_feat = torch.cat(
-                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
-                dim=0,
-            )
-            all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)
+            cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
+            all_cap_feats_out.append(cap_padded_feat)

             ### Process Image
             C, F, H, W = image.size()
@@ -564,6 +535,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         cap_feats: List[torch.Tensor],
         patch_size=2,
         f_patch_size=1,
+        return_dict: bool = True,
     ):
         assert patch_size in self.all_patch_size
         assert f_patch_size in self.all_f_patch_size

@@ -672,4 +644,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         unified = list(unified.unbind(dim=0))
         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
-        return x, {}
+        if not return_dict:
+            return (x,)
+        return Transformer2DModelOutput(sample=x)
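The new `return_dict` plumbing follows the usual diffusers output convention. A minimal, hedged sketch of how a caller sees it, reusing the tiny dummy config and list-style inputs from the new transformer tests further down (all sizes are test-only assumptions, not real checkpoint values):

```python
import torch
from diffusers import ZImageTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput

# Tiny config mirroring prepare_init_args_and_inputs_for_common() in the new tests.
transformer = ZImageTransformer2DModel(
    all_patch_size=(2,), all_f_patch_size=(1,), in_channels=16, dim=16,
    n_layers=1, n_refiner_layers=1, n_heads=1, n_kv_heads=2, qk_norm=True,
    cap_feat_dim=16, rope_theta=256.0, t_scale=1000.0,
    axes_dims=[8, 4, 4], axes_lens=[256, 32, 32],
).eval()

# Pin the torch.empty-initialized pad tokens, as the tests do, to avoid NaNs.
with torch.no_grad():
    transformer.x_pad_token.fill_(1.0)
    transformer.cap_pad_token.fill_(1.0)

# Z-Image takes lists of per-sample tensors rather than one batched tensor.
x = [torch.randn(16, 1, 16, 16)]   # (C, F, H, W) per sample
cap_feats = [torch.randn(16, 16)]  # (seq_len, cap_feat_dim) per sample
t = torch.tensor([0.0])

with torch.no_grad():
    out = transformer(x=x, t=t, cap_feats=cap_feats)  # Transformer2DModelOutput by default
    assert isinstance(out, Transformer2DModelOutput)
    samples = transformer(x=x, t=t, cap_feats=cap_feats, return_dict=False)[0]  # plain tuple
```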
@@ -525,9 +525,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))
                 model_out_list = self.transformer(
-                    latent_model_input_list,
-                    timestep_model_input,
-                    prompt_embeds_model_input,
+                    latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
                 )[0]

                 if apply_cfg:
...
@@ -15,17 +15,13 @@
 import sys
 import unittest

+import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel

-from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
+from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device

 if is_peft_available():

@@ -34,13 +30,9 @@ if is_peft_available():
 sys.path.append(".")

-from .utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402

-@unittest.skip(
-    "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
-    "and torch.empty padding tokens. LoRA functionality works correctly with real models."
-)
 @require_peft_backend
 class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = ZImagePipeline

@@ -127,6 +119,12 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)

         transformer = self.transformer_cls(**self.transformer_kwargs)
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN data values in our testing environment. Fixing them
+        # to ones helps prevent that issue.
+        with torch.no_grad():
+            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
         vae = self.vae_cls(**self.vae_kwargs)

         if scheduler_cls is None:

@@ -161,3 +159,127 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         }

         return pipeline_components, text_lora_config, denoiser_lora_config
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.delete_adapters("adapter-1")
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, _ in denoiser.named_modules():
if "to_k" in name and "attention" in name and "lora" not in name:
module_name_to_rank_update = name.replace(".base_layer.", ".")
break
# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
pipe.transformer.delete_adapters("adapter-1")
# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
@skip_mps
def test_lora_fuse_nan(self):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
possible_tower_names = ["noise_refiner"]
filtered_tower_names = [
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
]
for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
@unittest.skip("Needs to be debugged.")
def test_set_adapters_match_attention_kwargs(self):
super().test_set_adapters_match_attention_kwargs()
@unittest.skip("Needs to be debugged.")
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
@unittest.skip("Not supported in ZImage.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in ZImage.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in ZImage.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -47,6 +47,7 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 from diffusers.models.auto_model import AutoModel
+from diffusers.models.modeling_outputs import BaseOutput
 from diffusers.training_utils import EMAModel
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,

@@ -108,6 +109,11 @@ def check_if_lora_correctly_set(model) -> bool:
     return False

+def normalize_output(out):
+    out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out
+    return torch.stack(out0) if isinstance(out0, list) else out0

 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
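To make the intent of `normalize_output` concrete: it reduces the output flavors the common tests can encounter (a plain tensor, a tuple or `BaseOutput` wrapping one, or Z-Image's list of per-sample tensors) to a single stackable tensor. A small self-contained illustration, re-declaring the helper without the `BaseOutput` branch for brevity:

```python
import torch

def normalize_output(out):
    # Simplified re-statement of the helper above (tuple/list handling only).
    out0 = out[0] if isinstance(out, tuple) else out
    return torch.stack(out0) if isinstance(out0, list) else out0

plain = torch.randn(2, 4)                  # ordinary batched tensor output
listed = [torch.randn(4), torch.randn(4)]  # Z-Image-style list of per-sample tensors
wrapped = (listed,)                        # tuple output, e.g. from return_dict=False

assert normalize_output(plain).shape == (2, 4)    # passed through unchanged
assert normalize_output(listed).shape == (2, 4)   # list is stacked along dim 0
assert normalize_output(wrapped).shape == (2, 4)  # first element taken, then stacked
```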
@@ -536,6 +542,9 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]

+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")

@@ -780,6 +789,9 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]

+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")

@@ -842,6 +854,9 @@ class ModelTesterMixin:
         if isinstance(second, dict):
             second = second.to_tuple()[0]

+        first = normalize_output(first)
+        second = normalize_output(second)
         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
         out_1 = out_1[~np.isnan(out_1)]
@@ -860,11 +875,15 @@ class ModelTesterMixin:
         if isinstance(output, dict):
             output = output.to_tuple()[0]
+        if isinstance(output, list):
+            output = torch.stack(output)

         self.assertIsNotNone(output)

         # input & output have to have the same shape
         input_tensor = inputs_dict[self.main_input_name]
+        if isinstance(input_tensor, list):
+            input_tensor = torch.stack(input_tensor)

         if expected_output_shape is None:
             expected_shape = input_tensor.shape

@@ -898,11 +917,15 @@ class ModelTesterMixin:
         if isinstance(output_1, dict):
             output_1 = output_1.to_tuple()[0]
+        if isinstance(output_1, list):
+            output_1 = torch.stack(output_1)

         output_2 = new_model(**inputs_dict)

         if isinstance(output_2, dict):
             output_2 = output_2.to_tuple()[0]
+        if isinstance(output_2, list):
+            output_2 = torch.stack(output_2)

         self.assertEqual(output_1.shape, output_2.shape)
@@ -1138,6 +1161,8 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         output_no_lora = model(**inputs_dict, return_dict=False)[0]
+        if isinstance(output_no_lora, list):
+            output_no_lora = torch.stack(output_no_lora)

         denoiser_lora_config = LoraConfig(
             r=rank,

@@ -1151,6 +1176,8 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+        if isinstance(outputs_with_lora, list):
+            outputs_with_lora = torch.stack(outputs_with_lora)

         self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))

@@ -1175,6 +1202,8 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+        if isinstance(outputs_with_lora_2, list):
+            outputs_with_lora_2 = torch.stack(outputs_with_lora_2)

         self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
         self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
@@ -1296,31 +1325,35 @@ class ModelTesterMixin:
     def test_cpu_offload(self):
         if self.model_class._no_split_modules is None:
             pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")

         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
         model = model.to(torch_device)

         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)

             for max_size in max_gpu_sizes:
                 max_memory = {0: max_size, "cpu": model_size * 2}
                 new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                 # Making sure part of the model will actually end up offloaded
                 self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                 torch.manual_seed(0)
                 new_output = new_model(**inputs_dict)
+                new_normalized_output = normalize_output(new_output)

-                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+                self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):

@@ -1333,6 +1366,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_sizes(model)[""]
         max_size = int(self.model_split_percents[0] * model_size)

@@ -1352,8 +1386,8 @@ class ModelTesterMixin:
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)

             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
+            new_normalized_output = normalize_output(new_output)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):

@@ -1366,6 +1400,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:

@@ -1380,8 +1415,9 @@ class ModelTesterMixin:
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)

             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
+            new_normalized_output = normalize_output(new_output)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
     @require_torch_multi_accelerator
     def test_model_parallelism(self):

@@ -1422,6 +1458,7 @@ class ModelTesterMixin:
         model = model.to(torch_device)

         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.

@@ -1443,8 +1480,9 @@ class ModelTesterMixin:
         if "generator" in inputs_dict:
             _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         new_output = new_model(**inputs_dict)
+        new_normalized_output = normalize_output(new_output)

-        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+        self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_accelerator
     def test_sharded_checkpoints_with_variant(self):

@@ -1454,6 +1492,7 @@ class ModelTesterMixin:
         model = model.to(torch_device)

         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.

@@ -1481,8 +1520,9 @@ class ModelTesterMixin:
         if "generator" in inputs_dict:
             _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         new_output = new_model(**inputs_dict)
+        new_normalized_output = normalize_output(new_output)

-        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+        self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
     @require_torch_accelerator
     def test_sharded_checkpoints_with_parallel_loading(self):

@@ -1492,6 +1532,7 @@ class ModelTesterMixin:
         model = model.to(torch_device)

         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.

@@ -1515,7 +1556,9 @@ class ModelTesterMixin:
         if "generator" in inputs_dict:
             _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         new_output = new_model(**inputs_dict)
-        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+        new_normalized_output = normalize_output(new_output)
+        self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

         # set to no.
         os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"

@@ -1529,6 +1572,7 @@ class ModelTesterMixin:
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.

@@ -1549,7 +1593,9 @@ class ModelTesterMixin:
         if "generator" in inputs_dict:
             _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         new_output = new_model(**inputs_dict)
-        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+        new_normalized_output = normalize_output(new_output)
+        self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
     # This test is okay without a GPU because we're not running any execution. We're just serializing
     # and check if the resultant files are following an expected format.

@@ -1629,7 +1675,9 @@ class ModelTesterMixin:
         model = self.model_class(**config)
         model.eval()
         model.to(torch_device)
-        base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
+        base_slice = model(**inputs_dict)[0]
+        base_slice = normalize_output(base_slice)
+        base_slice = base_slice.detach().flatten().cpu().numpy()

         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN

@@ -1655,7 +1703,9 @@ class ModelTesterMixin:
             model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
             check_linear_dtype(model, storage_dtype, compute_dtype)
-            output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()
+            output = model(**inputs_dict)[0]
+            output = normalize_output(output)
+            output = output.float().flatten().detach().cpu().numpy()

             # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
             # We just want to make sure that the layerwise casting is working as expected.
@@ -1716,6 +1766,12 @@ class ModelTesterMixin:
     @parameterized.expand([False, True])
     @require_torch_accelerator
     def test_group_offloading(self, record_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading" in cls.__dict__ and cls is not ModelTesterMixin:
+                # Skip this test if it is overridden by a child class. We need to do this because
+                # `parameterized` materializes the test methods on invocation, and those materialized
+                # methods cannot be overridden.
+                pytest.skip("Model does not support group offloading.")
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")
@@ -1738,21 +1794,25 @@ class ModelTesterMixin:
         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
         output_with_group_offloading1 = run_forward(model)
+        output_with_group_offloading1 = normalize_output(output_with_group_offloading1)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
         output_with_group_offloading2 = run_forward(model)
+        output_with_group_offloading2 = normalize_output(output_with_group_offloading2)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="leaf_level")
         output_with_group_offloading3 = run_forward(model)
+        output_with_group_offloading3 = normalize_output(output_with_group_offloading3)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)

@@ -1760,6 +1820,7 @@ class ModelTesterMixin:
             torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
         )
         output_with_group_offloading4 = run_forward(model)
+        output_with_group_offloading4 = normalize_output(output_with_group_offloading4)

         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
@@ -1799,6 +1860,12 @@ class ModelTesterMixin:
     @torch.no_grad()
     @torch.inference_mode()
     def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_with_disk" in cls.__dict__ and cls is not ModelTesterMixin:
+                # Skip this test if it is overridden by a child class. We need to do this because
+                # `parameterized` materializes the test methods on invocation, and those materialized
+                # methods cannot be overridden.
+                pytest.skip("Model does not support group offloading with disk yet.")
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

@@ -1821,6 +1888,7 @@ class ModelTesterMixin:
         model.eval()
         model.to(torch_device)
         output_without_group_offloading = _run_forward(model, inputs_dict)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)

@@ -1856,6 +1924,7 @@ class ModelTesterMixin:
             raise ValueError(f"Following files are missing: {', '.join(missing_files)}")

         output_with_group_offloading = _run_forward(model, inputs_dict)
+        output_with_group_offloading = normalize_output(output_with_group_offloading)

         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))

     def test_auto_model(self, expected_max_diff=5e-5):
@@ -1889,10 +1958,17 @@ class ModelTesterMixin:
         output_original = model(**inputs_dict)
         output_auto = auto_model(**inputs_dict)

         if isinstance(output_original, dict):
             output_original = output_original.to_tuple()[0]
         if isinstance(output_auto, dict):
             output_auto = output_auto.to_tuple()[0]
+        if isinstance(output_original, list):
+            output_original = torch.stack(output_original)
+        if isinstance(output_auto, list):
+            output_auto = torch.stack(output_auto)

+        output_original, output_auto = output_original.float(), output_auto.float()
         max_diff = (output_original - output_auto).abs().max().item()
         self.assertLessEqual(

@@ -2083,6 +2159,8 @@ class TorchCompileTesterMixin:
         recompile_limit = 1
         if self.model_class.__name__ == "UNet2DConditionModel":
             recompile_limit = 2
+        elif self.model_class.__name__ == "ZImageTransformer2DModel":
+            recompile_limit = 3

         with (
             torch._inductor.utils.fresh_inductor_cache(),
@@ -2184,7 +2262,6 @@ class LoraHotSwappingForModelTesterMixin:
         backend_empty_cache(torch_device)

     def get_lora_config(self, lora_rank, lora_alpha, target_modules):
-        # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig

         lora_config = LoraConfig(
...
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import unittest
import torch
from diffusers import ZImageTransformer2DModel
from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
# Cannot use enable_full_determinism() which sets it to True
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if hasattr(torch.backends, "cuda"):
torch.backends.cuda.matmul.allow_tf32 = False
@unittest.skipIf(
IS_GITHUB_ACTIONS,
reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
)
class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = ZImageTransformer2DModel
main_input_name = "x"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.9, 0.9, 0.9]
def prepare_dummy_input(self, height=16, width=16):
batch_size = 1
num_channels = 16
embedding_dim = 16
sequence_length = 16
hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
encoder_hidden_states = [
torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
]
timestep = torch.tensor([0.0]).to(torch_device)
return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"all_patch_size": (2,),
"all_f_patch_size": (1,),
"in_channels": 16,
"dim": 16,
"n_layers": 1,
"n_refiner_layers": 1,
"n_heads": 1,
"n_kv_heads": 2,
"qk_norm": True,
"cap_feat_dim": 16,
"rope_theta": 256.0,
"t_scale": 1000.0,
"axes_dims": [8, 4, 4],
"axes_lens": [256, 32, 32],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def setUp(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
def tearDown(self):
super().tearDown()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"ZImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test is not supported for handling main inputs that are lists.")
def test_training(self):
super().test_training()
@unittest.skip("Test is not supported for handling main inputs that are lists.")
def test_ema_training(self):
super().test_ema_training()
@unittest.skip("Test is not supported for handling main inputs that are lists.")
def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing()
@unittest.skip(
"Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
)
def test_layerwise_casting_training(self):
super().test_layerwise_casting_training()
@unittest.skip("Test is not supported for handling main inputs that are lists.")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
@unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
def test_group_offloading(self):
super().test_group_offloading()
@unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
def test_group_offloading_with_disk(self):
super().test_group_offloading_with_disk()
class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = ZImageTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
@unittest.skip(
"The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
)
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
@unittest.skip("Fullgraph AoT is broken")
def test_compile_works_with_aot(self):
super().test_compile_works_with_aot()
@unittest.skip("Fullgraph is broken")
def test_compile_on_different_shapes(self):
super().test_compile_on_different_shapes()
@@ -20,12 +20,7 @@ import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel

 from ...testing_utils import torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS

@@ -106,6 +101,12 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             axes_dims=[8, 4, 4],
             axes_lens=[256, 32, 32],
         )
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN data values in our testing environment. Fixing them
+        # to ones helps prevent that issue.
+        with torch.no_grad():
+            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))

         torch.manual_seed(0)
         vae = AutoencoderKL(

@@ -183,7 +184,7 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         self.assertEqual(generated_image.shape, (3, 32, 32))

         # fmt: off
-        expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
+        expected_slice = torch.tensor([0.4622, 0.4532, 0.4714, 0.5087, 0.5371, 0.5405, 0.4492, 0.4479, 0.2984, 0.2783, 0.5409, 0.6577, 0.3952, 0.5524, 0.5262, 0.453])
         # fmt: on

         generated_slice = generated_image.flatten()
...