OpenDAS / diffusers
Commit abd86d1c (Unverified)
Authored May 06, 2023 by Sanchit Gandhi
Committed by GitHub, May 06, 2023
[AudioLDM] Generalise conversion script (#3328)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e9aa0925
Showing 1 changed file with 54 additions and 17 deletions
scripts/convert_original_audioldm_to_diffusers.py (+54 -17)
@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
     extract_ema: bool = False,
     scheduler_type: str = "ddim",
     num_in_channels: int = None,
+    model_channels: int = None,
+    num_head_channels: int = None,
     device: str = None,
     from_safetensors: bool = False,
 ) -> AudioLDMPipeline:
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
     global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
 
-    :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
-        corresponding to the original architecture.
-        If `None`, will be automatically instantiated based on default values.
-    :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
-    prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
-        AudioLDM checkpoints.
-    :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
-        inferred.
-    :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
-        "euler-ancestral", "dpm", "ddim"]`.
-    :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
-        the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
-        yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
-    :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
-        `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
-        instead of PyTorch.
-    :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+    Args:
+        checkpoint_path (`str`): Path to `.ckpt` file.
+        original_config_file (`str`):
+            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
+            set to the audioldm-s-full-v2 config.
+        image_size (`int`, *optional*, defaults to 512):
+            The image size that the model was trained on.
+        prediction_type (`str`, *optional*):
+            The prediction type that the model was trained on. If `None`, will be automatically
+            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
+        num_in_channels (`int`, *optional*, defaults to None):
+            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
+        model_channels (`int`, *optional*, defaults to None):
+            The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
+            to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
+        num_head_channels (`int`, *optional*, defaults to None):
+            The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
+            to 32 for the small and medium checkpoints, and 64 for the large.
+        scheduler_type (`str`, *optional*, defaults to 'pndm'):
+            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+            "ddim"]`.
+        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+            inference. Non-EMA weights are usually better to continue fine-tuning.
+        device (`str`, *optional*, defaults to `None`):
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`str`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+    return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
 
     if not is_omegaconf_available():
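The rewritten docstring gives per-size override values for `model_channels` and `num_head_channels`. As a minimal sketch, those values could be kept in a small lookup table and turned into keyword arguments for `load_pipeline_from_original_audioldm_ckpt`. The names `AUDIOLDM_UNET_OVERRIDES` and `unet_overrides_for` are hypothetical and not part of the script; only the numbers come from the docstring.

```python
# Hypothetical helper, not part of the conversion script.
# Values are the ones quoted in the docstring:
# small 128/32, medium 192/32, large 256/64.
AUDIOLDM_UNET_OVERRIDES = {
    "small": {"model_channels": 128, "num_head_channels": 32},
    "medium": {"model_channels": 192, "num_head_channels": 32},
    "large": {"model_channels": 256, "num_head_channels": 64},
}


def unet_overrides_for(size: str) -> dict:
    """Return the UNet override keyword arguments for a checkpoint size."""
    try:
        # Copy so callers cannot mutate the shared table.
        return dict(AUDIOLDM_UNET_OVERRIDES[size])
    except KeyError:
        expected = sorted(AUDIOLDM_UNET_OVERRIDES)
        raise ValueError(f"Unknown checkpoint size {size!r}; expected one of {expected}") from None
```

The result could then be unpacked into the call, for example `load_pipeline_from_original_audioldm_ckpt(checkpoint_path=path, **unet_overrides_for("large"))`, where `path` is a placeholder. Note that the signature in this commit shows `scheduler_type: str = "ddim"` and `from_safetensors: bool = False`, while the docstring text says `'pndm'` and `str`. The signature is what determines the actual defaults.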
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
 
+    if model_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
+
+    if num_head_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
+
     if (
         "parameterization" in original_config["model"]["params"]
         and original_config["model"]["params"]["parameterization"] == "v"
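The override pattern in this hunk (write a value into the UNet section of the config only when the caller supplied one) can be sketched on a plain nested `dict`. This is an illustration under stated assumptions: the script operates on the loaded original config rather than a hand-built `dict`, the helper name `apply_unet_overrides` is hypothetical, and the starting values in `config` are made up.

```python
from typing import Optional


def apply_unet_overrides(
    original_config: dict,
    num_in_channels: Optional[int] = None,
    model_channels: Optional[int] = None,
    num_head_channels: Optional[int] = None,
) -> dict:
    """Write every non-None override into the UNet params of the config, in place."""
    unet_params = original_config["model"]["params"]["unet_config"]["params"]
    overrides = {
        "in_channels": num_in_channels,
        "model_channels": model_channels,
        "num_head_channels": num_head_channels,
    }
    for key, value in overrides.items():
        if value is not None:  # None means "keep whatever the config already says"
            unet_params[key] = value
    return original_config


# Stand-in config with made-up starting values, for illustration only.
config = {
    "model": {
        "params": {
            "unet_config": {
                "params": {"in_channels": 8, "model_channels": 128, "num_head_channels": 32}
            }
        }
    }
}
apply_unet_overrides(config, model_channels=256, num_head_channels=64)
```

Leaving an argument at `None` keeps the config value untouched, which matches the "automatically inferred from the config" wording in the docstring.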
@@ -960,6 +981,20 @@ if __name__ == "__main__":
         type=int,
         help="The number of input channels. If `None` number of input channels will be automatically inferred.",
     )
+    parser.add_argument(
+        "--model_channels",
+        default=None,
+        type=int,
+        help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
+    )
+    parser.add_argument(
+        "--num_head_channels",
+        default=None,
+        type=int,
+        help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 32 for the small and medium checkpoints, and 64 for the large.",
+    )
     parser.add_argument(
         "--scheduler_type",
         default="ddim",
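To see how the two new flags behave, here is a small standalone sketch that declares only `--model_channels` and `--num_head_channels` and parses a simulated command line. It uses a fresh `argparse.ArgumentParser`, not the script's full parser, and the argument values are just the "large" numbers from the help text.

```python
import argparse

# Standalone parser with only the two flags added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--model_channels", default=None, type=int)
parser.add_argument("--num_head_channels", default=None, type=int)

# Simulates passing: --model_channels 256 --num_head_channels 64
args = parser.parse_args(["--model_channels", "256", "--num_head_channels", "64"])

# With no flags, both stay None, so the config values would be used.
defaults = parser.parse_args([])

# Keyword arguments as they would be forwarded to the conversion function.
overrides = {
    "model_channels": args.model_channels,
    "num_head_channels": args.num_head_channels,
}
```

Because `type=int` is set, the values arrive as integers rather than strings, and omitting a flag leaves it at `None`.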
@@ -1009,6 +1044,8 @@ if __name__ == "__main__":
         extract_ema=args.extract_ema,
         scheduler_type=args.scheduler_type,
         num_in_channels=args.num_in_channels,
+        model_channels=args.model_channels,
+        num_head_channels=args.num_head_channels,
         from_safetensors=args.from_safetensors,
         device=args.device,
     )