".github/vscode:/vscode.git/clone" did not exist on "5a75fa9f1a21bf981f770655df848eedd0854799"
Unverified commit e65b71ab, authored by Anton Lozhkov and committed by GitHub
Browse files

Add an explicit `--image_size` to the conversion script (#1509)

* Add an explicit `--image_size` to the conversion script

* style
parent a6a25ceb
...@@ -207,12 +207,12 @@ def conv_attn_to_linear(checkpoint): ...@@ -207,12 +207,12 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0] checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config(original_config): def create_unet_diffusers_config(original_config, image_size: int):
""" """
Creates a config for the diffusers based on the config of the LDM model. Creates a config for the diffusers based on the config of the LDM model.
""" """
model_params = original_config.model.params
unet_params = original_config.model.params.unet_config.params unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
...@@ -230,8 +230,10 @@ def create_unet_diffusers_config(original_config): ...@@ -230,8 +230,10 @@ def create_unet_diffusers_config(original_config):
up_block_types.append(block_type) up_block_types.append(block_type)
resolution //= 2 resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
config = dict( config = dict(
sample_size=model_params.image_size, sample_size=image_size // vae_scale_factor,
in_channels=unet_params.in_channels, in_channels=unet_params.in_channels,
out_channels=unet_params.out_channels, out_channels=unet_params.out_channels,
down_block_types=tuple(down_block_types), down_block_types=tuple(down_block_types),
...@@ -245,7 +247,7 @@ def create_unet_diffusers_config(original_config): ...@@ -245,7 +247,7 @@ def create_unet_diffusers_config(original_config):
return config return config
def create_vae_diffusers_config(original_config): def create_vae_diffusers_config(original_config, image_size: int):
""" """
Creates a config for the diffusers based on the config of the LDM model. Creates a config for the diffusers based on the config of the LDM model.
""" """
...@@ -257,7 +259,7 @@ def create_vae_diffusers_config(original_config): ...@@ -257,7 +259,7 @@ def create_vae_diffusers_config(original_config):
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict( config = dict(
sample_size=vae_params.resolution, sample_size=image_size,
in_channels=vae_params.in_channels, in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch, out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types), down_block_types=tuple(down_block_types),
...@@ -653,6 +655,15 @@ if __name__ == "__main__": ...@@ -653,6 +655,15 @@ if __name__ == "__main__":
type=str, type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
) )
parser.add_argument(
"--image_size",
default=512,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
parser.add_argument( parser.add_argument(
"--extract_ema", "--extract_ema",
action="store_true", action="store_true",
...@@ -712,7 +723,7 @@ if __name__ == "__main__": ...@@ -712,7 +723,7 @@ if __name__ == "__main__":
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config) unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
converted_unet_checkpoint = convert_ldm_unet_checkpoint( converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
) )
...@@ -721,7 +732,7 @@ if __name__ == "__main__": ...@@ -721,7 +732,7 @@ if __name__ == "__main__":
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model. # Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config) vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment