Unverified Commit 60286132 authored by hlky's avatar hlky Committed by GitHub
Browse files

Z-Image-Turbo `from_single_file` (#12756)

* Z-Image-Turbo `from_single_file`

* compute_dtype

* -device cast
parent a1f36ee3
......@@ -49,6 +49,7 @@ from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
convert_z_image_transformer_checkpoint_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
......@@ -167,6 +168,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ZImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
......
......@@ -120,6 +120,7 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"z-image-turbo": "cap_embedder.0.weight",
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
......@@ -218,6 +219,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
}
# Use to configure model sample size when original config is provided
......@@ -721,6 +723,12 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"
elif (
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
):
model_type = "z-image-turbo"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
model_type = "lumina2"
......@@ -3824,3 +3832,56 @@ def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
Z_IMAGE_KEYS_RENAME_DICT = {
"final_layer.": "all_final_layer.2-1.",
"x_embedder.": "all_x_embedder.2-1.",
".attention.out.bias": ".attention.to_out.0.bias",
".attention.k_norm.weight": ".attention.norm_k.weight",
".attention.q_norm.weight": ".attention.norm_q.weight",
".attention.out.weight": ".attention.to_out.0.weight",
}
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
if ".attention.qkv.weight" not in key:
return
fused_qkv_weight = state_dict.pop(key)
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")
state_dict[new_q_name] = to_q_weight
state_dict[new_k_name] = to_k_weight
state_dict[new_v_name] = to_v_weight
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {
".attention.qkv.weight": convert_z_image_fused_attention,
}
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle single file --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
......@@ -63,8 +63,11 @@ class TimestepEmbedder(nn.Module):
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
weight_dtype = self.mlp[0].weight.dtype
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
if weight_dtype.is_floating_point:
t_freq = t_freq.to(weight_dtype)
elif compute_dtype is not None:
t_freq = t_freq.to(compute_dtype)
t_emb = self.mlp(t_freq)
return t_emb
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment