"src/torio/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "7c988b438a57a293da63cfc2acbddf4d01b281c5"
Unverified Commit b4226bd6 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] fix config checking tests (#7247)

* debig

* cast tuples to lists.

* debug

* handle upcast attention

* handle downblock types for vae.

* remove print.

* address Dhruv's comments.

* fix: upblock types.

* upcast attention

* debug

* debug

* debug

* better guarding.

* style
parent 46fac824
......@@ -462,8 +462,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"down_block_types": down_block_types,
"block_out_channels": block_out_channels,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
......@@ -482,7 +482,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
config["num_class_embeds"] = unet_params["num_classes"]
config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types)
config["up_block_types"] = up_block_types
return config
......@@ -530,9 +530,9 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
"sample_size": image_size,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"down_block_types": down_block_types,
"up_block_types": up_block_types,
"block_out_channels": block_out_channels,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": scaling_factor,
......
......@@ -1088,6 +1088,8 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
pipe.unet.config[param_name] = False
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment