Unverified Commit abd86d1c authored by Sanchit Gandhi, committed by GitHub

[AudioLDM] Generalise conversion script (#3328)
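
This generalises the AudioLDM conversion script beyond the default small checkpoint by exposing the UNet `model_channels` and `num_head_channels` as overridable arguments. A minimal usage sketch of the new arguments (the import path and checkpoint filename are illustrative assumptions, not part of this commit):

```python
# Hypothetical example: convert a medium AudioLDM checkpoint with the new
# UNet overrides. Module name and file paths are illustrative assumptions.
from convert_original_audioldm_to_diffusers import load_pipeline_from_original_audioldm_ckpt

pipeline = load_pipeline_from_original_audioldm_ckpt(
    checkpoint_path="audioldm-m-full.ckpt",  # hypothetical local checkpoint
    model_channels=192,    # medium checkpoints use 192 (see docstring)
    num_head_channels=32,  # small and medium checkpoints use 32
)
pipeline.save_pretrained("audioldm-m-full-diffusers")
```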


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e9aa0925
@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
    extract_ema: bool = False,
    scheduler_type: str = "ddim",
    num_in_channels: int = None,
    model_channels: int = None,
    num_head_channels: int = None,
    device: str = None,
    from_safetensors: bool = False,
) -> AudioLDMPipeline:
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path (`str`): Path to `.ckpt` file.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            set to the audioldm-s-full-v2 config.
        image_size (`int`, *optional*, defaults to 512):
            The image size that the model was trained on.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. If `None`, it will be automatically inferred by
            looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
        num_in_channels (`int`, *optional*, defaults to `None`):
            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
        model_channels (`int`, *optional*, defaults to `None`):
            The number of UNet model channels. If `None`, it will be automatically inferred from the config.
            Override to 128 for the small checkpoints, 192 for the medium checkpoints, and 256 for the large.
        num_head_channels (`int`, *optional*, defaults to `None`):
            The number of UNet head channels. If `None`, it will be automatically inferred from the config.
            Override to 32 for the small and medium checkpoints, and 64 for the large.
        scheduler_type (`str`, *optional*, defaults to `'ddim'`):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`.
        extract_ema (`bool`, *optional*, defaults to `False`):
            Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
            weights. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
            inference. Non-EMA weights are usually better for continued fine-tuning.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`bool`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load the checkpoint with safetensors instead of
            PyTorch.

    Returns:
        An `AudioLDMPipeline` object representing the passed-in `.ckpt`/`.safetensors` file.
    """
    if not is_omegaconf_available():
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
    if model_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels

    if num_head_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
@@ -960,6 +981,20 @@ if __name__ == "__main__":
        type=int,
        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--model_channels",
        default=None,
        type=int,
        help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
        " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
    )
    parser.add_argument(
        "--num_head_channels",
        default=None,
        type=int,
        help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
        " to 32 for the small and medium checkpoints, and 64 for the large.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="ddim",
@@ -1009,6 +1044,8 @@ if __name__ == "__main__":
        extract_ema=args.extract_ema,
        scheduler_type=args.scheduler_type,
        num_in_channels=args.num_in_channels,
        model_channels=args.model_channels,
        num_head_channels=args.num_head_channels,
        from_safetensors=args.from_safetensors,
        device=args.device,
    )
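For reference, the docstring values above map checkpoint size to the UNet overrides as follows; a hypothetical helper (not part of this commit) that builds the corresponding keyword arguments:

```python
# Hypothetical helper (not part of this commit): checkpoint size -> the UNet
# overrides documented in the docstring above.
UNET_OVERRIDES = {
    "small": {"model_channels": 128, "num_head_channels": 32},
    "medium": {"model_channels": 192, "num_head_channels": 32},
    "large": {"model_channels": 256, "num_head_channels": 64},
}

# Usage sketch:
# load_pipeline_from_original_audioldm_ckpt(checkpoint_path, **UNET_OVERRIDES["large"])
```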