renzhc / diffusers_dcu / Commits / abd86d1c

Unverified commit abd86d1c, authored May 06, 2023 by Sanchit Gandhi, committed by GitHub on May 06, 2023
[AudioLDM] Generalise conversion script (#3328)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e9aa0925
Changes: 1 changed file with 54 additions and 17 deletions
scripts/convert_original_audioldm_to_diffusers.py (+54, -17), view file @ abd86d1c
@@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
     extract_ema: bool = False,
     scheduler_type: str = "ddim",
     num_in_channels: int = None,
+    model_channels: int = None,
+    num_head_channels: int = None,
     device: str = None,
     from_safetensors: bool = False,
 ) -> AudioLDMPipeline:
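For illustration, a minimal sketch of calling the extended function with the two new overrides; the checkpoint path, override values, and output directory below are hypothetical, not taken from the commit:

# Hedged sketch: load an original AudioLDM checkpoint while forcing the
# UNet width and head size instead of relying on config inference.
pipe = load_pipeline_from_original_audioldm_ckpt(
    checkpoint_path="audioldm-m-full.ckpt",  # hypothetical path
    model_channels=192,      # medium checkpoint, per the docstring below
    num_head_channels=32,    # small/medium checkpoints use 32
)
pipe.save_pretrained("audioldm-m-full")      # hypothetical output directory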
@@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
     global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
-    :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
-        corresponding to the original architecture.
-        If `None`, will be automatically instantiated based on default values.
-    :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
-    prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
-        AudioLDM checkpoints.
-    :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
-        inferred.
-    :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
-        "euler-ancestral", "dpm", "ddim"]`.
-    :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
-        the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
-        yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
-    :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
-        `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
-        instead of PyTorch.
-    :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+    Args:
+        checkpoint_path (`str`): Path to `.ckpt` file.
+        original_config_file (`str`):
+            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
+            set to the audioldm-s-full-v2 config.
+        image_size (`int`, *optional*, defaults to 512):
+            The image size that the model was trained on.
+        prediction_type (`str`, *optional*):
+            The prediction type that the model was trained on. If `None`, will be automatically inferred by looking
+            for a key in the config. For the default config, the prediction type is `'epsilon'`.
+        num_in_channels (`int`, *optional*, defaults to None):
+            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
+        model_channels (`int`, *optional*, defaults to None):
+            The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
+            to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
+        num_head_channels (`int`, *optional*, defaults to None):
+            The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
+            to 32 for the small and medium checkpoints, and 64 for the large.
+        scheduler_type (`str`, *optional*, defaults to 'pndm'):
+            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+            "ddim"]`.
+        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+            inference. Non-EMA weights are usually better to continue fine-tuning.
+        device (`str`, *optional*, defaults to `None`):
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`str`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+        return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
 
     if not is_omegaconf_available():
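The guard above depends on omegaconf. A minimal sketch, assuming a hypothetical config file name, of how the original yaml config is typically loaded once that check passes:

from omegaconf import OmegaConf

# Load the original AudioLDM training config. The UNet hyperparameters that
# the num_in_channels / model_channels / num_head_channels flags override
# all live under model.params.unet_config.params.
original_config = OmegaConf.load("audioldm_original.yaml")  # hypothetical path
unet_params = original_config["model"]["params"]["unet_config"]["params"]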
@@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
 
+    if model_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
+
+    if num_head_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
+
     if (
         "parameterization" in original_config["model"]["params"]
         and original_config["model"]["params"]["parameterization"] == "v"
@@ -960,6 +981,20 @@ if __name__ == "__main__":
         type=int,
         help="The number of input channels. If `None` number of input channels will be automatically inferred.",
     )
+    parser.add_argument(
+        "--model_channels",
+        default=None,
+        type=int,
+        help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
+    )
+    parser.add_argument(
+        "--num_head_channels",
+        default=None,
+        type=int,
+        help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
+        " to 32 for the small and medium checkpoints, and 64 for the large.",
+    )
     parser.add_argument(
         "--scheduler_type",
         default="ddim",
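A stripped-down, runnable stand-in for the two new CLI flags (flag names and defaults taken from the diff; the surrounding parser is a sketch, not the script's full argument list):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_channels", default=None, type=int,
                    help="Override UNet model channels: 128 (small), 192 (medium), 256 (large).")
parser.add_argument("--num_head_channels", default=None, type=int,
                    help="Override UNet head channels: 32 (small/medium), 64 (large).")

# Unspecified flags stay None, so the script falls back to config inference.
args = parser.parse_args(["--model_channels", "192"])
print(args.model_channels, args.num_head_channels)  # 192 None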
@@ -1009,6 +1044,8 @@ if __name__ == "__main__":
         extract_ema=args.extract_ema,
         scheduler_type=args.scheduler_type,
         num_in_channels=args.num_in_channels,
+        model_channels=args.model_channels,
+        num_head_channels=args.num_head_channels,
         from_safetensors=args.from_safetensors,
         device=args.device,
     )
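Putting it together, a hypothetical end-to-end invocation of the updated script; --checkpoint_path and --dump_path are assumed to match the script's standard conversion flags, and both paths are illustrative only:

import subprocess

# Convert a medium AudioLDM checkpoint, overriding the UNet width/head size
# with the two flags this commit adds (values per the docstring above).
subprocess.run(
    [
        "python", "scripts/convert_original_audioldm_to_diffusers.py",
        "--checkpoint_path", "audioldm-m-full.ckpt",  # hypothetical input
        "--model_channels", "192",
        "--num_head_channels", "32",
        "--dump_path", "audioldm-m-full",  # hypothetical output dir (assumed flag)
    ],
    check=True,
)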