Unverified commit e65b71ab authored by Anton Lozhkov, committed by GitHub

Add an explicit `--image_size` to the conversion script (#1509)

* Add an explicit `--image_size` to the conversion script

* style
parent a6a25ceb
@@ -207,12 +207,12 @@ def conv_attn_to_linear(checkpoint):
                 checkpoint[key] = checkpoint[key][:, :, 0]
 
 
-def create_unet_diffusers_config(original_config):
+def create_unet_diffusers_config(original_config, image_size: int):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
-    model_params = original_config.model.params
     unet_params = original_config.model.params.unet_config.params
+    vae_params = original_config.model.params.first_stage_config.params.ddconfig
 
     block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
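For orientation, `original_config` is the original LDM YAML config, which this script loads with OmegaConf elsewhere. A minimal sketch of pulling the same sub-configs (the filename is hypothetical):

```python
# Minimal sketch, assuming the LDM YAML is loaded with OmegaConf as elsewhere
# in this script; "v1-inference.yaml" is a hypothetical local path.
from omegaconf import OmegaConf

original_config = OmegaConf.load("v1-inference.yaml")
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
print(list(vae_params.ch_mult))  # [1, 2, 4, 4] for the standard SD VAE
```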
@@ -230,8 +230,10 @@ def create_unet_diffusers_config(original_config):
         up_block_types.append(block_type)
         resolution //= 2
 
+    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+
     config = dict(
-        sample_size=model_params.image_size,
+        sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
         out_channels=unet_params.out_channels,
         down_block_types=tuple(down_block_types),
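The new `sample_size` is the UNet's latent resolution rather than the pixel resolution: each VAE downsampling stage halves the resolution, so the scale factor is `2 ** (len(ch_mult) - 1)`. A quick worked example, assuming the standard Stable Diffusion VAE with `ch_mult = [1, 2, 4, 4]`:

```python
# Worked example of the sample_size arithmetic above; ch_mult = [1, 2, 4, 4]
# is the standard Stable Diffusion VAE configuration.
ch_mult = [1, 2, 4, 4]
vae_scale_factor = 2 ** (len(ch_mult) - 1)  # 8

for image_size in (512, 768):
    print(image_size, "->", image_size // vae_scale_factor)
# 512 -> 64  (UNet sample_size for SD v1.X and v2-base)
# 768 -> 96  (UNet sample_size for SD v2)
```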
@@ -245,7 +247,7 @@ def create_unet_diffusers_config(original_config):
     return config
 
 
-def create_vae_diffusers_config(original_config):
+def create_vae_diffusers_config(original_config, image_size: int):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
@@ -257,7 +259,7 @@ def create_vae_diffusers_config(original_config):
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
 
     config = dict(
-        sample_size=vae_params.resolution,
+        sample_size=image_size,
         in_channels=vae_params.in_channels,
         out_channels=vae_params.out_ch,
         down_block_types=tuple(down_block_types),
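Note the asymmetry between the two builders: the VAE operates in pixel space, so its `sample_size` is the full `image_size`, while the UNet operates in latent space and gets `image_size // vae_scale_factor`. A sanity-check sketch, assuming an SD v2 (768) config whose VAE has the standard scale factor of 8:

```python
# Sanity-check sketch using the two config builders from this script; assumes
# an SD v2 (768) checkpoint with the standard VAE scale factor of 8.
unet_config = create_unet_diffusers_config(original_config, image_size=768)
vae_config = create_vae_diffusers_config(original_config, image_size=768)

assert vae_config["sample_size"] == 768        # pixel space
assert unet_config["sample_size"] == 768 // 8  # latent space
```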
@@ -653,6 +655,15 @@ if __name__ == "__main__":
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
     )
+    parser.add_argument(
+        "--image_size",
+        default=512,
+        type=int,
+        help=(
+            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
+            " Base. Use 768 for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
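As a small illustration of how the new flag behaves (the real parser has additional, partly required arguments, so this uses a throwaway parser with just the new option):

```python
# Illustrative only: a throwaway parser mirroring the --image_size declaration
# above, to show the default vs. explicit behavior.
import argparse

demo = argparse.ArgumentParser()
demo.add_argument("--image_size", default=512, type=int)

assert demo.parse_args([]).image_size == 512  # SD v1.X and v2-base checkpoints
assert demo.parse_args(["--image_size", "768"]).image_size == 768  # SD v2 (768)
```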
@@ -712,7 +723,7 @@ if __name__ == "__main__":
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config)
+    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
@@ -721,7 +732,7 @@ if __name__ == "__main__":
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config)
+    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
     vae = AutoencoderKL(**vae_config)