ModelZoo / VibeVoice_pytorch — Commits

Commit b4af4e0c — "Initial commit"
Authored Sep 01, 2025 by luopl
Changes: 48 — showing 8 changed files with 2415 additions and 0 deletions (+2415 −0)
vibevoice/processor/__init__.py                                     +0     −0
vibevoice/processor/vibevoice_processor.py                          +678   −0
vibevoice/processor/vibevoice_tokenizer_processor.py                +484   −0
vibevoice/schedule/__init__.py                                      +0     −0
vibevoice/schedule/dpm_solver.py                                    +1066  −0
vibevoice/schedule/timestep_sampler.py                              +20    −0
vibevoice/scripts/__init__.py                                       +0     −0
vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py    +167   −0
vibevoice/processor/__init__.py  (new file, mode 100644, empty)
vibevoice/processor/vibevoice_processor.py  (new file, mode 100644)
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re

import numpy as np
import torch

from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging

from .vibevoice_tokenizer_processor import AudioNormalizer

logger = logging.get_logger(__name__)


class VibeVoiceProcessor:
    r"""
    Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.

    [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
    See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.

    Args:
        tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
            The tokenizer for text processing.
        audio_processor (`VibeVoiceTokenizerProcessor`):
            The audio processor for speech processing.
        speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
            The compression ratio for speech tokenization.
        db_normalize (`bool`, *optional*, defaults to `True`):
            Whether to apply decibel normalization to audio inputs.
    """

    def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained model
                - a path to a *directory* containing the processor config

        Returns:
            [`VibeVoiceProcessor`]: The processor object instantiated from the pretrained model.
        """
        import os
        import json
        from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import (
            VibeVoiceTextTokenizer,
            VibeVoiceTextTokenizerFast,
        )

        # Load processor configuration
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
            config = {
                "speech_tok_compress_ratio": 3200,
                "db_normalize": True,
            }

        # Extract main processor parameters
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Load tokenizer - try from model path first, then fall back to Qwen
        language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
        if 'qwen' in language_model_pretrained_name.lower():
            tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(language_model_pretrained_name, **kwargs)
        else:
            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")

        # Load audio processor
        if "audio_processor" in config:
            # Create audio processor from config
            audio_config = config["audio_processor"]
            audio_processor = VibeVoiceTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-6),
            )
        else:
            # Create default audio processor
            audio_processor = VibeVoiceTokenizerProcessor()

        # Create and return the processor
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save a processor to a directory, so that it can be re-loaded using the
        [`~VibeVoiceProcessor.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor will be saved.
        """
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        # Save processor configuration
        processor_config = {
            "processor_class": "VibeVoiceProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "VibeVoiceTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
                "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
                "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
                "eps": getattr(self.audio_processor, 'eps', 1e-6),
            },
        }

        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, 'w') as f:
            json.dump(processor_config, f, indent=2)

        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(
        self,
        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process one or more podcast scripts with optional voice samples.

        Args:
            text (`str`, `List[str]`):
                The input text(s) to process. Can be:
                - A single script string
                - A list of script strings for batch processing
                - A path to a .json or .txt file
                - A list of paths
            voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
                Voice samples for each script. Can be:
                - A list of samples for a single script
                - A list of lists for batch processing
            padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
                Whether to pad sequences to the same length.
            truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
                Whether to truncate sequences.
            max_length (`int`, *optional*):
                Maximum length of the returned sequences.
            return_tensors (`str` or `TensorType`, *optional*):
                If set, will return tensors of a particular framework.
            return_attention_mask (`bool`, defaults to `True`):
                Whether to return the attention mask.

        Returns:
            `BatchEncoding`: A BatchEncoding with the following fields:
                - **input_ids** -- List of token id sequences or tensor
                - **attention_mask** -- List of attention masks or tensor
                - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
                - **speech_masks** -- Speech masks (if voice_samples provided)
                - **speech_input_mask** -- Boolean masks indicating speech token positions
        """
        # Handle single vs batch input
        if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
            # Single input
            texts = [text]
            is_batched = False
        else:
            # Batch input
            texts = text
            is_batched = True

        # Handle voice samples
        if voice_samples is not None:
            if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
                # Single set of voice samples
                voice_samples_list = [voice_samples]
            else:
                # Batch of voice samples
                voice_samples_list = voice_samples
        else:
            voice_samples_list = [None] * len(texts)

        # Process each input
        all_encodings = []
        for text_input, voice_input in zip(texts, voice_samples_list):
            encoding = self._process_single(text_input, voice_input)
            all_encodings.append(encoding)

        # Combine batch
        batch_encoding = self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

        return batch_encoding

    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """Process a single podcast script."""
        # Determine if text is a file path or direct script
        script = None
        if isinstance(text, str):
            # Check if it's a file path
            if text.endswith('.json') and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith('.txt') and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                # Assume it's the script content directly
                script = text

        if script is None:
            raise ValueError(f"Could not process input text: {text}")

        # Parse the script
        parsed_lines = self._parse_script(script)
        all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))

        # Create system prompt
        # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
        system_tokens = self.tokenizer.encode(self.system_prompt)

        # Process voice samples if provided
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []

        # Build full token sequence
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks

        # Add text input section
        full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
        speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))

        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)

        # Add speech output section
        full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
        speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)

        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Combine multiple encodings into a batch with padding."""
        # Extract input_ids and create attention_mask
        input_ids_list = [enc["input_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]

        # Determine padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif isinstance(padding, str):
            padding_strategy = PaddingStrategy(padding)
        else:
            padding_strategy = padding

        # Apply padding to input_ids
        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            if padding_strategy == PaddingStrategy.LONGEST:
                max_len = max(len(ids) for ids in input_ids_list)
            elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            # Pad sequences
            padded_input_ids = []
            attention_masks = []
            padded_speech_input_masks = []

            for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
                # Truncate if needed
                if truncation and len(input_ids) > max_len:
                    input_ids = input_ids[:max_len]
                    speech_mask = speech_mask[:max_len]

                # Pad
                padding_length = max_len - len(input_ids)
                # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
                padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
                attention_mask = [0] * padding_length + [1] * len(input_ids)
                padded_speech_mask = [False] * padding_length + speech_mask

                padded_input_ids.append(padded_ids)
                attention_masks.append(attention_mask)
                padded_speech_input_masks.append(padded_speech_mask)

            input_ids_list = padded_input_ids
            speech_input_masks_list = padded_speech_input_masks
        else:
            # No padding, just create attention masks
            attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None

        # Process speech inputs
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True

        # Prepare batch encoding
        batch_encoding = BatchEncoding()

        # Handle tensor conversion
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
            batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
        else:
            batch_encoding["input_ids"] = input_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list

        # Process speech tensors if present
        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs,
                return_tensors=return_tensors,
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None

        # Add metadata
        batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
        batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]

        return batch_encoding

    def _create_voice_prompt(self, speaker_samples: List[Union[str, np.ndarray]]) -> Tuple[List[int], List[np.ndarray], List[bool]]:
        """
        Create voice prompt tokens and process audio samples.

        Returns:
            tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
        """
        vae_token_id = self.tokenizer.speech_diffusion_id
        voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
        voice_speech_inputs = []
        voice_speech_masks = [False] * len(voice_full_tokens)

        for speaker_id, speaker_audio in enumerate(speaker_samples):
            prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)

            # Process audio
            if isinstance(speaker_audio, str):
                # Load audio from file
                wav = self.audio_processor._load_audio_from_path(speaker_audio)
            else:
                wav = np.array(speaker_audio, dtype=np.float32)

            # Apply normalization if needed
            if self.db_normalize and self.audio_normalizer:
                wav = self.audio_normalizer(wav)

            # Calculate token length based on compression ratio
            # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
            #     vae_tok_len = wav.shape[0]
            # else:
            vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)

            # Build tokens and masks
            speaker_tokens = (
                prefix_tokens
                + [self.tokenizer.speech_start_id]
                + [vae_token_id] * vae_tok_len
                + [self.tokenizer.speech_end_id]
                + self.tokenizer.encode('\n', add_special_tokens=False)
            )
            vae_input_mask = ([False] * len(prefix_tokens) + [False] + [True] * vae_tok_len + [False] + [False])

            voice_full_tokens.extend(speaker_tokens)
            voice_speech_masks.extend(vae_input_mask)
            voice_speech_inputs.append(wav)

        return voice_full_tokens, voice_speech_inputs, voice_speech_masks

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """
        Prepare speech inputs for model consumption.

        Args:
            speech_inputs: List of speech arrays
            return_tensors: Output tensor type
            device: Device to place tensors on
            dtype: Data type for tensors

        Returns:
            Dictionary with padded_speeches and speech_masks
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        # Calculate sequence lengths
        vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
        # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
        max_speech_length = max(s.shape[0] for s in speech_inputs)

        # Pad speeches
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
        else:
            padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
        speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)

        for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
            padded_speeches[i, :len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True

        result = {
            "padded_speeches": padded_speeches,
            "speech_masks": speech_masks,
        }

        # Convert to tensors if requested
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
            result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)

        return result

    def _convert_json_to_script(self, json_file: str) -> str:
        """
        Convert JSON format to script format.

        Expected JSON format:
        [
            {"speaker": "1", "text": "Hello everyone..."},
            {"speaker": "2", "text": "Great to be here..."}
        ]
        """
        import json

        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of speaker entries")

        script_lines = []
        for item in data:
            if not isinstance(item, dict):
                logger.warning(f"Skipping non-dict entry: {item}")
                continue

            speaker = item.get('speaker')
            text = item.get('text')

            if speaker is None or text is None:
                logger.warning(f"Skipping entry missing speaker or text: {item}")
                continue

            # Ensure speaker ID is valid
            try:
                speaker_id = int(speaker)
            except (ValueError, TypeError):
                logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
                continue

            # Clean up text
            text = text.strip()
            if text:
                script_lines.append(f"Speaker {speaker_id}: {text}")

        if not script_lines:
            raise ValueError("No valid entries found in JSON file")

        return "\n".join(script_lines)

    def _convert_text_to_script(self, text_file: str) -> str:
        """
        Convert text file to script format.

        Handles multiple formats:
        1. Already formatted as "Speaker X: text"
        2. Plain text (assigns to Speaker 1)

        Handles edge cases like multiple colons in a line.
        """
        with open(text_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        script_lines = []
        current_speaker = 1

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Try to parse as "Speaker X: text" format
            # Use regex to be more robust
            speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)

            if speaker_match:
                speaker_id = int(speaker_match.group(1))
                text = speaker_match.group(2).strip()
                if text:
                    script_lines.append(f"Speaker {speaker_id}: {text}")
            else:
                # Treat as plain text - assign to current speaker
                script_lines.append(f"Speaker {current_speaker}: {line}")

        if not script_lines:
            raise ValueError("No valid content found in text file")

        return "\n".join(script_lines)

    def _parse_script(self, script: str) -> List[Tuple[int, str]]:
        """Parse script into list of (speaker_id, text) tuples."""
        lines = script.strip().split("\n")
        parsed_lines = []
        speaker_ids = []

        # First pass: parse all lines and collect speaker IDs
        for line in lines:
            if not line.strip():
                continue

            # Use regex to handle edge cases like multiple colons
            match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
            if match:
                speaker_id = int(match.group(1))
                text = ' ' + match.group(2).strip()
                parsed_lines.append((speaker_id, text))
                speaker_ids.append(speaker_id)
            else:
                logger.warning(f"Could not parse line: '{line}'")

        if not parsed_lines:
            raise ValueError("No valid speaker lines found in script")

        # Check if we need to normalize speaker IDs (only if all are > 0)
        min_speaker_id = min(speaker_ids)
        if min_speaker_id > 0:
            # Normalize to start from 0
            normalized_lines = []
            for speaker_id, text in parsed_lines:
                normalized_lines.append((speaker_id - 1, text))
            return normalized_lines
        else:
            # Keep original IDs
            return parsed_lines

    def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
        """Merge text and audio inputs into a single BatchEncoding."""
        # Start with text inputs
        merged = BatchEncoding(text_inputs)

        # Add audio-specific fields
        if "audio" in audio_inputs:
            merged["speech_inputs"] = audio_inputs["audio"]
        if "streaming" in audio_inputs:
            merged["streaming"] = audio_inputs["streaming"]

        return merged

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Return the list of inputs accepted by the model.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> str:
        """
        Save audio data to a file.

        Args:
            audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
                The audio data to save. Can be a single tensor/array or a list of them.
            output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
            sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
            normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
            batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".

        Returns:
            str: The path to the saved audio file.
        """
        return self.audio_processor.save_audio(
            audio,
            output_path=output_path,
            sampling_rate=sampling_rate,
            normalize=normalize,
            batch_prefix=batch_prefix,
        )


__all__ = [
    "VibeVoiceProcessor",
]
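To make the processor's intended call pattern concrete, here is a minimal usage sketch (not part of the commit). The checkpoint directory and voice-sample paths are hypothetical; the sketch assumes the directory contains a `preprocessor_config.json` and that the Qwen tokenizer referenced above is reachable. With the default `speech_tok_compress_ratio` of 3200, a 10-second reference clip at 24 kHz (240,000 samples) contributes `math.ceil(240000 / 3200) = 75` speech-diffusion token positions to the voice prompt.

```python
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# Hypothetical local checkpoint directory containing preprocessor_config.json
processor = VibeVoiceProcessor.from_pretrained("./VibeVoice-ckpt")

script = "Speaker 1: Welcome to the show.\nSpeaker 2: Happy to be here."
# One reference clip per speaker (hypothetical file paths)
voice_samples = ["voices/alice.wav", "voices/bob.wav"]

inputs = processor(
    text=script,
    voice_samples=voice_samples,
    return_tensors="pt",
)

print(inputs["input_ids"].shape)          # (1, sequence_length)
print(inputs["speech_tensors"].shape)     # (num_voice_clips, max_samples)
print(inputs["speech_input_mask"].sum())  # number of speech-diffusion token positions
```

The returned `speech_input_mask` marks exactly those voice-prompt positions, which is how the acoustic features in `speech_tensors` are later aligned with the token sequence.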
vibevoice/processor/vibevoice_tokenizer_processor.py  (new file, mode 100644)
"""
Processor class for VibeVoice models.
"""
import
os
import
json
import
warnings
from
typing
import
List
,
Optional
,
Union
,
Dict
,
Any
import
numpy
as
np
import
torch
from
transformers.feature_extraction_utils
import
FeatureExtractionMixin
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
AudioNormalizer
:
"""
Audio normalization class for VibeVoice tokenizer.
This class provides audio normalization to ensure consistent input levels
for the VibeVoice tokenizer while maintaining audio quality.
"""
def
__init__
(
self
,
target_dB_FS
:
float
=
-
25
,
eps
:
float
=
1e-6
):
"""
Initialize the audio normalizer.
Args:
target_dB_FS (float): Target dB FS level for the audio. Default: -25
eps (float): Small value to avoid division by zero. Default: 1e-6
"""
self
.
target_dB_FS
=
target_dB_FS
self
.
eps
=
eps
def
tailor_dB_FS
(
self
,
audio
:
np
.
ndarray
)
->
tuple
:
"""
Adjust the audio to the target dB FS level.
Args:
audio (np.ndarray): Input audio signal
Returns:
tuple: (normalized_audio, rms, scalar)
"""
rms
=
np
.
sqrt
(
np
.
mean
(
audio
**
2
))
scalar
=
10
**
(
self
.
target_dB_FS
/
20
)
/
(
rms
+
self
.
eps
)
normalized_audio
=
audio
*
scalar
return
normalized_audio
,
rms
,
scalar
def
avoid_clipping
(
self
,
audio
:
np
.
ndarray
,
scalar
:
Optional
[
float
]
=
None
)
->
tuple
:
"""
Avoid clipping by scaling down if necessary.
Args:
audio (np.ndarray): Input audio signal
scalar (float, optional): Explicit scaling factor
Returns:
tuple: (normalized_audio, scalar)
"""
if
scalar
is
None
:
max_val
=
np
.
max
(
np
.
abs
(
audio
))
if
max_val
>
1.0
:
scalar
=
max_val
+
self
.
eps
else
:
scalar
=
1.0
return
audio
/
scalar
,
scalar
def
__call__
(
self
,
audio
:
np
.
ndarray
)
->
np
.
ndarray
:
"""
Normalize the audio by adjusting to target dB FS and avoiding clipping.
Args:
audio (np.ndarray): Input audio signal
Returns:
np.ndarray: Normalized audio signal
"""
# First adjust to target dB FS
audio
,
_
,
_
=
self
.
tailor_dB_FS
(
audio
)
# Then avoid clipping
audio
,
_
=
self
.
avoid_clipping
(
audio
)
return
audio
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
class
VibeVoiceTokenizerProcessor
(
FeatureExtractionMixin
):
"""
Processor for VibeVoice acoustic tokenizer models.
This processor handles audio preprocessing for VibeVoice models, including:
- Audio format conversion (stereo to mono)
- Optional audio normalization
- Streaming support for infinite-length audio
Args:
sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
"""
model_input_names
=
[
"input_features"
]
def
__init__
(
self
,
sampling_rate
:
int
=
24000
,
normalize_audio
:
bool
=
True
,
target_dB_FS
:
float
=
-
25
,
eps
:
float
=
1e-6
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
sampling_rate
=
sampling_rate
self
.
normalize_audio
=
normalize_audio
# Initialize audio normalizer if needed
if
self
.
normalize_audio
:
self
.
normalizer
=
AudioNormalizer
(
target_dB_FS
=
target_dB_FS
,
eps
=
eps
)
else
:
self
.
normalizer
=
None
# Save config
self
.
feature_extractor_dict
=
{
"sampling_rate"
:
sampling_rate
,
"normalize_audio"
:
normalize_audio
,
"target_dB_FS"
:
target_dB_FS
,
"eps"
:
eps
,
}
def
_ensure_mono
(
self
,
audio
:
np
.
ndarray
)
->
np
.
ndarray
:
"""
Convert stereo audio to mono if needed.
Args:
audio (np.ndarray): Input audio array
Returns:
np.ndarray: Mono audio array
"""
if
len
(
audio
.
shape
)
==
1
:
return
audio
elif
len
(
audio
.
shape
)
==
2
:
if
audio
.
shape
[
0
]
==
2
:
# (2, time)
return
np
.
mean
(
audio
,
axis
=
0
)
elif
audio
.
shape
[
1
]
==
2
:
# (time, 2)
return
np
.
mean
(
audio
,
axis
=
1
)
else
:
# If one dimension is 1, squeeze it
if
audio
.
shape
[
0
]
==
1
:
return
audio
.
squeeze
(
0
)
elif
audio
.
shape
[
1
]
==
1
:
return
audio
.
squeeze
(
1
)
else
:
raise
ValueError
(
f
"Unexpected audio shape:
{
audio
.
shape
}
"
)
else
:
raise
ValueError
(
f
"Audio should be 1D or 2D, got shape:
{
audio
.
shape
}
"
)
def
_process_single_audio
(
self
,
audio
:
Union
[
np
.
ndarray
,
List
[
float
]])
->
np
.
ndarray
:
"""
Process a single audio array.
Args:
audio: Single audio input
Returns:
np.ndarray: Processed audio
"""
# Convert to numpy array
if
not
isinstance
(
audio
,
np
.
ndarray
):
audio
=
np
.
array
(
audio
,
dtype
=
np
.
float32
)
else
:
audio
=
audio
.
astype
(
np
.
float32
)
# Ensure mono
audio
=
self
.
_ensure_mono
(
audio
)
# Normalize if requested
if
self
.
normalize_audio
and
self
.
normalizer
is
not
None
:
audio
=
self
.
normalizer
(
audio
)
return
audio
def
__call__
(
self
,
audio
:
Union
[
str
,
np
.
ndarray
,
List
[
float
],
List
[
np
.
ndarray
],
List
[
List
[
float
]],
List
[
str
]]
=
None
,
sampling_rate
:
Optional
[
int
]
=
None
,
return_tensors
:
Optional
[
str
]
=
None
,
**
kwargs
,
):
"""
Process audio for VibeVoice models.
Args:
audio: Audio input(s) to process. Can be:
- str: Path to audio file
- np.ndarray: Audio array
- List[float]: Audio as list of floats
- List[np.ndarray]: Batch of audio arrays
- List[str]: Batch of audio file paths
sampling_rate (int, optional): Sampling rate of the input audio
return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
Returns:
dict: Processed audio inputs with keys:
- input_features: Audio tensor(s) ready for the model
"""
if
audio
is
None
:
raise
ValueError
(
"Audio input is required"
)
# Validate sampling rate
if
sampling_rate
is
not
None
and
sampling_rate
!=
self
.
sampling_rate
:
logger
.
warning
(
f
"Input sampling rate (
{
sampling_rate
}
) differs from expected "
f
"sampling rate (
{
self
.
sampling_rate
}
). Please resample your audio."
)
# Handle different input types
if
isinstance
(
audio
,
str
):
# Single audio file path
audio
=
self
.
_load_audio_from_path
(
audio
)
is_batched
=
False
elif
isinstance
(
audio
,
list
):
if
len
(
audio
)
==
0
:
raise
ValueError
(
"Empty audio list provided"
)
# Check if it's a list of file paths
if
all
(
isinstance
(
item
,
str
)
for
item
in
audio
):
# Batch of audio file paths
audio
=
[
self
.
_load_audio_from_path
(
path
)
for
path
in
audio
]
is_batched
=
True
else
:
# Check if it's batched audio arrays
is_batched
=
isinstance
(
audio
[
0
],
(
np
.
ndarray
,
list
))
else
:
# Single audio array or list
is_batched
=
False
# Process audio
if
is_batched
:
processed_audio
=
[
self
.
_process_single_audio
(
a
)
for
a
in
audio
]
else
:
processed_audio
=
[
self
.
_process_single_audio
(
audio
)]
# Convert to tensors if requested
if
return_tensors
==
"pt"
:
if
len
(
processed_audio
)
==
1
:
# Create a proper batch dimension (B, T)
input_features
=
torch
.
from_numpy
(
processed_audio
[
0
]).
unsqueeze
(
0
).
unsqueeze
(
1
)
else
:
# For batched input with different lengths, create a batch properly
input_features
=
torch
.
stack
([
torch
.
from_numpy
(
a
)
for
a
in
processed_audio
]).
unsqueeze
(
1
)
elif
return_tensors
==
"np"
:
if
len
(
processed_audio
)
==
1
:
input_features
=
processed_audio
[
0
][
np
.
newaxis
,
np
.
newaxis
,
:]
else
:
input_features
=
np
.
stack
(
processed_audio
)[:,
np
.
newaxis
,
:]
else
:
input_features
=
processed_audio
[
0
]
if
len
(
processed_audio
)
==
1
else
processed_audio
outputs
=
{
"audio"
:
input_features
,
# Use "audio" instead of "input_features"
}
return
outputs
def
_load_audio_from_path
(
self
,
audio_path
:
str
)
->
np
.
ndarray
:
"""
Load audio from file path.
Args:
audio_path (str): Path to audio file
Returns:
np.ndarray: Loaded audio array
"""
# Get file extension to determine loading method
file_ext
=
os
.
path
.
splitext
(
audio_path
)[
1
].
lower
()
if
file_ext
in
[
'.wav'
,
'.mp3'
,
'.flac'
,
'.m4a'
,
'.ogg'
]:
# Audio file - use librosa
import
librosa
audio_array
,
sr
=
librosa
.
load
(
audio_path
,
sr
=
self
.
sampling_rate
,
mono
=
True
)
return
audio_array
elif
file_ext
==
'.pt'
:
# PyTorch tensor file
audio_tensor
=
torch
.
load
(
audio_path
,
map_location
=
'cpu'
).
squeeze
()
if
isinstance
(
audio_tensor
,
torch
.
Tensor
):
audio_array
=
audio_tensor
.
numpy
()
else
:
audio_array
=
np
.
array
(
audio_tensor
)
return
audio_array
.
astype
(
np
.
float32
)
elif
file_ext
==
'.npy'
:
# NumPy file
audio_array
=
np
.
load
(
audio_path
)
return
audio_array
.
astype
(
np
.
float32
)
else
:
raise
ValueError
(
f
"Unsupported file format:
{
file_ext
}
. "
f
"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
)
def
preprocess_audio
(
self
,
audio_path_or_array
:
Union
[
str
,
np
.
ndarray
],
normalize
:
Optional
[
bool
]
=
None
,
)
->
np
.
ndarray
:
"""
Convenience method to preprocess audio from file path or array.
This method is kept for backward compatibility but __call__ is recommended.
Args:
audio_path_or_array: Path to audio file or numpy array
normalize: Whether to normalize (overrides default setting)
Returns:
np.ndarray: Preprocessed audio array
"""
if
isinstance
(
audio_path_or_array
,
str
):
audio_array
=
self
.
_load_audio_from_path
(
audio_path_or_array
)
else
:
audio_array
=
np
.
array
(
audio_path_or_array
,
dtype
=
np
.
float32
)
# Override normalization setting if specified
original_normalize
=
self
.
normalize_audio
if
normalize
is
not
None
:
self
.
normalize_audio
=
normalize
try
:
processed
=
self
.
_process_single_audio
(
audio_array
)
finally
:
# Restore original setting
self
.
normalize_audio
=
original_normalize
return
processed
# Override to_dict method for configuration saving
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
"""
Convert the object to a dict containing all attributes needed for serialization.
"""
return
self
.
feature_extractor_dict
def
save_audio
(
self
,
audio
:
Union
[
torch
.
Tensor
,
np
.
ndarray
,
List
[
Union
[
torch
.
Tensor
,
np
.
ndarray
]]],
output_path
:
str
=
"output.wav"
,
sampling_rate
:
Optional
[
int
]
=
None
,
normalize
:
bool
=
False
,
batch_prefix
:
str
=
"audio_"
,
):
"""
Save audio data to WAV file(s).
Args:
audio: Audio data to save. Can be:
- torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
- np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
- List of tensors or arrays
output_path: Path where to save the audio. If saving multiple files,
this is treated as a directory and individual files will be saved inside.
sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
normalize: Whether to normalize audio before saving.
batch_prefix: Prefix for batch files when saving multiple audios.
Returns:
List[str]: Paths to the saved audio files.
"""
if
sampling_rate
is
None
:
sampling_rate
=
self
.
sampling_rate
try
:
import
soundfile
as
sf
except
ImportError
:
raise
ImportError
(
"soundfile is required to save audio files. "
"Install it with: pip install soundfile"
)
# Ensure audio is in the right format
if
isinstance
(
audio
,
torch
.
Tensor
):
# Convert PyTorch tensor to numpy
audio_np
=
audio
.
float
().
detach
().
cpu
().
numpy
()
elif
isinstance
(
audio
,
np
.
ndarray
):
audio_np
=
audio
elif
isinstance
(
audio
,
list
):
# Handle list of tensors or arrays
if
all
(
isinstance
(
a
,
torch
.
Tensor
)
for
a
in
audio
):
audio_np
=
[
a
.
float
().
detach
().
cpu
().
numpy
()
for
a
in
audio
]
else
:
audio_np
=
audio
else
:
raise
ValueError
(
f
"Unsupported audio type:
{
type
(
audio
)
}
"
)
saved_paths
=
[]
# Handle based on shape or type
if
isinstance
(
audio_np
,
list
):
# Multiple separate audios to save
output_dir
=
output_path
# Ensure output directory exists
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
# Save each audio
for
i
,
audio_item
in
enumerate
(
audio_np
):
audio_item
=
self
.
_prepare_audio_for_save
(
audio_item
,
normalize
)
file_path
=
os
.
path
.
join
(
output_dir
,
f
"
{
batch_prefix
}{
i
}
.wav"
)
sf
.
write
(
file_path
,
audio_item
,
sampling_rate
)
saved_paths
.
append
(
file_path
)
else
:
# Handle different dimensions
if
len
(
audio_np
.
shape
)
>=
3
:
# (B, C, T) or similar
# Get batch size
batch_size
=
audio_np
.
shape
[
0
]
if
batch_size
>
1
:
# Multiple audios in a batch
output_dir
=
output_path
# Ensure output directory exists
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
# Save each audio in the batch
for
i
in
range
(
batch_size
):
# Extract single audio and remove channel dim if present
single_audio
=
audio_np
[
i
]
if
len
(
single_audio
.
shape
)
>
1
:
if
single_audio
.
shape
[
0
]
==
1
:
# (1, T)
single_audio
=
single_audio
.
squeeze
(
0
)
single_audio
=
self
.
_prepare_audio_for_save
(
single_audio
,
normalize
)
file_path
=
os
.
path
.
join
(
output_dir
,
f
"
{
batch_prefix
}{
i
}
.wav"
)
sf
.
write
(
file_path
,
single_audio
,
sampling_rate
)
saved_paths
.
append
(
file_path
)
else
:
# Single audio with batch and channel dims
audio_item
=
audio_np
.
squeeze
()
# Remove batch and channel dimensions
audio_item
=
self
.
_prepare_audio_for_save
(
audio_item
,
normalize
)
sf
.
write
(
output_path
,
audio_item
,
sampling_rate
)
saved_paths
.
append
(
output_path
)
else
:
# Single audio without batch dimension
audio_item
=
self
.
_prepare_audio_for_save
(
audio_np
,
normalize
)
sf
.
write
(
output_path
,
audio_item
,
sampling_rate
)
saved_paths
.
append
(
output_path
)
return
saved_paths
def
_prepare_audio_for_save
(
self
,
audio
:
np
.
ndarray
,
normalize
:
bool
)
->
np
.
ndarray
:
"""
Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
Args:
audio: Audio data as numpy array
normalize: Whether to normalize audio
Returns:
np.ndarray: Processed audio ready for saving
"""
# Ensure right dimensionality
if
len
(
audio
.
shape
)
>
1
and
audio
.
shape
[
0
]
==
1
:
# (1, T)
audio
=
audio
.
squeeze
(
0
)
# Normalize if requested
if
normalize
:
max_val
=
np
.
abs
(
audio
).
max
()
if
max_val
>
0
:
audio
=
audio
/
max_val
return
audio
__all__
=
[
"VibeVoiceTokenizerProcessor"
,
"AudioNormalizer"
]
\ No newline at end of file
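As a sanity check on the normalization math above, here is a small sketch (not part of the commit, using a synthetic one-second tone as input). `tailor_dB_FS` applies `scalar = 10 ** (target_dB_FS / 20) / (rms + eps)`; for a full-scale sine (RMS ≈ 0.707) and the default −25 dBFS target this gives roughly 0.056 / 0.707 ≈ 0.08, and `avoid_clipping` then leaves the result untouched because the peak stays below 1.0.

```python
import numpy as np
from vibevoice.processor.vibevoice_tokenizer_processor import (
    AudioNormalizer,
    VibeVoiceTokenizerProcessor,
)

# One second of a full-scale 440 Hz sine at 24 kHz (synthetic example input)
t = np.arange(24000) / 24000.0
wav = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

normalizer = AudioNormalizer(target_dB_FS=-25)
normalized = normalizer(wav)
rms_db = 20 * np.log10(np.sqrt(np.mean(normalized ** 2)) + 1e-12)
print(f"RMS after normalization: {rms_db:.1f} dBFS")  # ≈ -25.0

# The feature extractor wraps the same normalization plus mono conversion
proc = VibeVoiceTokenizerProcessor(sampling_rate=24000, normalize_audio=True)
out = proc(wav, return_tensors="pt")
print(out["audio"].shape)  # torch.Size([1, 1, 24000])
```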
vibevoice/schedule/__init__.py  (new file, mode 100644, empty)
vibevoice/schedule/dpm_solver.py  (new file, mode 100644)
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import
math
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.utils
import
deprecate
from
diffusers.utils.torch_utils
import
randn_tensor
from
diffusers.schedulers.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
,
SchedulerOutput
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
max_beta
=
0.999
,
alpha_transform_type
=
"cosine"
,
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if
alpha_transform_type
==
"cosine"
:
def
alpha_bar_fn
(
t
):
return
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
# return math.cos(t * math.pi / 2 * 0.95) ** 2
elif
alpha_transform_type
==
"exp"
:
def
alpha_bar_fn
(
t
):
return
math
.
exp
(
t
*
-
12.0
)
elif
alpha_transform_type
==
"cauchy"
:
# µ + γ tan (π (0.5 - x)) γ = 1, µ = 3
# alpha^2 = 1-1/(exp(λ)+1)
def
alpha_bar_fn
(
t
,
gamma
=
1
,
mu
=
3
):
snr
=
mu
+
gamma
*
math
.
tan
(
math
.
pi
*
(
0.5
-
t
)
*
0.9
)
return
1
-
1
/
(
math
.
exp
(
snr
)
+
1.1
)
elif
alpha_transform_type
==
"laplace"
:
# µ − bsgn(0.5 − t) log(1 − 2|t − 0.5|) µ = 0, b = 1
def
alpha_bar_fn
(
t
,
mu
=
0
,
b
=
1
):
snr
=
mu
-
b
*
math
.
copysign
(
1
,
0.5
-
t
)
*
math
.
log
(
1
-
2
*
abs
(
t
-
0.5
)
*
0.98
)
return
1
-
1
/
(
math
.
exp
(
snr
)
+
1.02
)
else
:
raise
ValueError
(
f
"Unsupported alpha_transform_type:
{
alpha_transform_type
}
"
)
betas
=
[]
for
i
in
range
(
num_diffusion_timesteps
):
t1
=
i
/
num_diffusion_timesteps
t2
=
(
i
+
1
)
/
num_diffusion_timesteps
betas
.
append
(
min
(
1
-
alpha_bar_fn
(
t2
)
/
alpha_bar_fn
(
t1
),
max_beta
))
return
torch
.
tensor
(
betas
,
dtype
=
torch
.
float32
)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def
rescale_zero_terminal_snr
(
betas
):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas
=
1.0
-
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
dim
=
0
)
alphas_bar_sqrt
=
alphas_cumprod
.
sqrt
()
# Store old values.
alphas_bar_sqrt_0
=
alphas_bar_sqrt
[
0
].
clone
()
alphas_bar_sqrt_T
=
alphas_bar_sqrt
[
-
1
].
clone
()
# Shift so the last timestep is zero.
alphas_bar_sqrt
-=
alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt
*=
alphas_bar_sqrt_0
/
(
alphas_bar_sqrt_0
-
alphas_bar_sqrt_T
)
# Convert alphas_bar_sqrt to betas
alphas_bar
=
alphas_bar_sqrt
**
2
# Revert sqrt
alphas
=
alphas_bar
[
1
:]
/
alphas_bar
[:
-
1
]
# Revert cumprod
alphas
=
torch
.
cat
([
alphas_bar
[
0
:
1
],
alphas
])
betas
=
1
-
alphas
return
betas
class
DPMSolverMultistepScheduler
(
SchedulerMixin
,
ConfigMixin
):
"""
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
`lambda(t)`.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
contains the predicted Gaussian variance.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
order
=
1
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
:
int
=
1000
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]]
=
None
,
solver_order
:
int
=
2
,
prediction_type
:
str
=
"epsilon"
,
thresholding
:
bool
=
False
,
dynamic_thresholding_ratio
:
float
=
0.995
,
sample_max_value
:
float
=
1.0
,
algorithm_type
:
str
=
"dpmsolver++"
,
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
euler_at_final
:
bool
=
False
,
use_karras_sigmas
:
Optional
[
bool
]
=
False
,
use_lu_lambdas
:
Optional
[
bool
]
=
False
,
final_sigmas_type
:
Optional
[
str
]
=
"zero"
,
# "zero", "sigma_min"
lambda_min_clipped
:
float
=
-
float
(
"inf"
),
variance_type
:
Optional
[
str
]
=
None
,
timestep_spacing
:
str
=
"linspace"
,
steps_offset
:
int
=
0
,
rescale_betas_zero_snr
:
bool
=
False
,
):
if
algorithm_type
in
[
"dpmsolver"
,
"sde-dpmsolver"
]:
deprecation_message
=
f
"algorithm_type
{
algorithm_type
}
is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate
(
"algorithm_types dpmsolver and sde-dpmsolver"
,
"1.0.0"
,
deprecation_message
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
# this schedule is very specific to the latent diffusion model.
self
.
betas
=
torch
.
linspace
(
beta_start
**
0.5
,
beta_end
**
0.5
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
**
2
elif
beta_schedule
==
"squaredcos_cap_v2"
or
beta_schedule
==
"cosine"
:
# Glide cosine schedule
self
.
betas
=
betas_for_alpha_bar
(
num_train_timesteps
,
alpha_transform_type
=
"cosine"
)
elif
beta_schedule
==
"cauchy"
:
self
.
betas
=
betas_for_alpha_bar
(
num_train_timesteps
,
alpha_transform_type
=
"cauchy"
)
elif
beta_schedule
==
"laplace"
:
self
.
betas
=
betas_for_alpha_bar
(
num_train_timesteps
,
alpha_transform_type
=
"laplace"
)
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
is not implemented for
{
self
.
__class__
}
"
)
if
rescale_betas_zero_snr
:
self
.
betas
=
rescale_zero_terminal_snr
(
self
.
betas
)
self
.
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
torch
.
cumprod
(
self
.
alphas
,
dim
=
0
)
if
rescale_betas_zero_snr
:
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
self
.
alphas_cumprod
[
-
1
]
=
2
**-
24
# Currently we only support VP-type noise schedule
self
.
alpha_t
=
torch
.
sqrt
(
self
.
alphas_cumprod
)
self
.
sigma_t
=
torch
.
sqrt
(
1
-
self
.
alphas_cumprod
)
self
.
lambda_t
=
torch
.
log
(
self
.
alpha_t
)
-
torch
.
log
(
self
.
sigma_t
)
self
.
sigmas
=
((
1
-
self
.
alphas_cumprod
)
/
self
.
alphas_cumprod
)
**
0.5
# standard deviation of the initial noise distribution
self
.
init_noise_sigma
=
1.0
# settings for DPM-Solver
if
algorithm_type
not
in
[
"dpmsolver"
,
"dpmsolver++"
,
"sde-dpmsolver"
,
"sde-dpmsolver++"
]:
if
algorithm_type
==
"deis"
:
self
.
register_to_config
(
algorithm_type
=
"dpmsolver++"
)
else
:
raise
NotImplementedError
(
f
"
{
algorithm_type
}
is not implemented for
{
self
.
__class__
}
"
)
if
solver_type
not
in
[
"midpoint"
,
"heun"
]:
if
solver_type
in
[
"logrho"
,
"bh1"
,
"bh2"
]:
self
.
register_to_config
(
solver_type
=
"midpoint"
)
else
:
raise
NotImplementedError
(
f
"
{
solver_type
}
is not implemented for
{
self
.
__class__
}
"
)
if
algorithm_type
not
in
[
"dpmsolver++"
,
"sde-dpmsolver++"
]
and
final_sigmas_type
==
"zero"
:
raise
ValueError
(
f
"`final_sigmas_type`
{
final_sigmas_type
}
is not supported for `algorithm_type`
{
algorithm_type
}
. Please choose `sigma_min` instead."
)
# setable values
self
.
num_inference_steps
=
None
timesteps
=
np
.
linspace
(
0
,
num_train_timesteps
-
1
,
num_train_timesteps
,
dtype
=
np
.
float32
)[::
-
1
].
copy
()
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
)
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
def
step_index
(
self
):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timestep schedules. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                must be `None`, and the `timestep_spacing` attribute is ignored.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")

        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Clipping the minimum of all lambda(t) for numerical stability.
            # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
            clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
            last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

            # "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_steps is a power of 3
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_steps is a power of 3
                timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_lu_lambdas:
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
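
    # A minimal usage sketch of the schedule set-up above (illustrative only; `scheduler`, `model`,
    # and `latents` are hypothetical names, not defined in this file):
    #
    #     scheduler.set_timesteps(num_inference_steps=20, device="cuda")
    #     for t in scheduler.timesteps:
    #         noise_pred = model(latents, t)
    #         latents = scheduler.step(noise_pred, t, latents).prev_sample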

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t

        return alpha_t, sigma_t
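
    # Note on the parameterization above: alpha_t = 1 / sqrt(sigma**2 + 1) and
    # sigma_t = sigma / sqrt(sigma**2 + 1), so alpha_t**2 + sigma_t**2 == 1. This is the
    # variance-preserving form assumed by the solver updates below.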

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
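
    # The Karras schedule above interpolates linearly in sigma**(1/rho) space:
    #     sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho
    # with rho = 7.0, which spaces steps more densely near sigma_min.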

    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Lu et al. (2022)."""

        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()

        rho = 1.0  # 1.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = lambda_min ** (1 / rho)
        max_inv_rho = lambda_max ** (1 / rho)
        lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return lambdas

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon
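
    # The conversions above follow from the forward process x_t = alpha_t * x_0 + sigma_t * eps
    # (with alpha_t**2 + sigma_t**2 = 1):
    #     epsilon prediction:  x_0 = (x_t - sigma_t * eps) / alpha_t
    #     v prediction:        x_0 = alpha_t * x_t - sigma_t * v,   eps = alpha_t * v + sigma_t * x_t
    # so DPM-Solver++ (data prediction) and DPM-Solver (noise prediction) can share one trained model.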

    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        return x_t
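
    # In log-SNR terms (lambda = log(alpha) - log(sigma), h = lambda_t - lambda_s), the first-order
    # DPM-Solver++ branch above reads
    #     x_t = (sigma_t / sigma_s) * x_s - alpha_t * (exp(-h) - 1) * x0_pred
    # which is the DDIM step rewritten in the data-prediction parameterization.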

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
        else:
            noise = None

        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)
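
    # Each call to `step` consumes one entry of `self.timesteps`. The solver order is lowered
    # automatically during warm-up (higher orders need previous model outputs) and, when
    # `lower_order_final`/`euler_at_final` applies, for the final step(s), which keeps the last
    # update numerically stable at small step counts.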

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure alpha_t and sigma_t have the same device and dtype as original_samples
        # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
        # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
        alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
        sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        alpha_t = alpha_t[timesteps].flatten()
        while len(alpha_t.shape) < len(original_samples.shape):
            alpha_t = alpha_t.unsqueeze(-1)

        sigma_t = sigma_t[timesteps].flatten()
        while len(sigma_t.shape) < len(original_samples.shape):
            sigma_t = sigma_t.unsqueeze(-1)

        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
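
    # `add_noise` implements the forward (noising) process used during training,
    #     x_t = alpha_t * x_0 + sigma_t * eps,  eps ~ N(0, I),
    # and `get_velocity` below returns the corresponding v-prediction target,
    #     v_t = alpha_t * eps - sigma_t * x_0.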

    def get_velocity(
        self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor
    ) -> torch.Tensor:
        # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
        # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
        alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
        sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        alpha_t = alpha_t[timesteps].flatten()
        while len(alpha_t.shape) < len(original_samples.shape):
            alpha_t = alpha_t.unsqueeze(-1)

        sigma_t = sigma_t[timesteps].flatten()
        while len(sigma_t.shape) < len(original_samples.shape):
            sigma_t = sigma_t.unsqueeze(-1)

        velocity = alpha_t * noise - sigma_t * original_samples
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps
\ No newline at end of file
vibevoice/schedule/timestep_sampler.py
0 → 100644
View file @
b4af4e0c
import math

import torch


class UniformSampler:
    def __init__(self, timesteps=1000):
        self.timesteps = timesteps

    def sample(self, batch_size, device):
        return torch.randint(0, self.timesteps, (batch_size,), device=device)


class LogitNormalSampler:
    def __init__(self, timesteps=1000, m=0, s=1):
        self.timesteps = timesteps
        timesteps = torch.linspace(0, 1, timesteps)
        logit = torch.log(timesteps / (1 - timesteps))
        self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s**2) / (s * math.sqrt(2 * math.pi))

    def sample(self, batch_size, device):
        return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
\ No newline at end of file
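The two samplers above draw training timesteps for the diffusion objective: `UniformSampler` picks integer steps uniformly, while `LogitNormalSampler` builds a discrete logit-normal density over the `timesteps` bins and samples indices from it with `torch.multinomial`. A minimal sketch of how they might be driven is below; the batch size and device handling here are illustrative, not taken from the repository's training code.

import torch

from vibevoice.schedule.timestep_sampler import LogitNormalSampler, UniformSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

uniform = UniformSampler(timesteps=1000)
logit_normal = LogitNormalSampler(timesteps=1000, m=0, s=1)

# Both samplers return integer timestep indices in [0, timesteps) with shape (batch_size,)
t_uniform = uniform.sample(batch_size=8, device=device)
t_logit = logit_normal.sample(batch_size=8, device=device)
print(t_uniform.shape, t_logit.shape)  # torch.Size([8]) torch.Size([8])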
vibevoice/scripts/__init__.py
0 → 100644
View file @
b4af4e0c
vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py
0 → 100644
View file @
b4af4e0c
#!/usr/bin/env python
# coding=utf-8
import argparse
import json
import os
from pathlib import Path
import re

import torch
from typing import Dict, List, Tuple

from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from transformers.utils import logging

logger = logging.get_logger(__name__)


def convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path: str,
    pytorch_dump_folder_path: str,
    config_path: str = None,
):
    """
    Convert an nnscaler VibeVoice checkpoint to HuggingFace format.
    Supports both regular checkpoints and tensor parallel checkpoints.
    """
    # Load regular checkpoint
    logger.info(f"Loading regular checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # keys: ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
    # config = checkpoint['train_args']
    init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
    pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
    init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
    if init_config_path.exists():
        logger.info(f"Loading initial config from {init_config_path}")
        with open(init_config_path, 'r') as f:
            init_config = json.load(f)
    else:
        raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")

    tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
    logger.info(f"Tie word embeddings: {tie_word_embeddings}")

    init_config['decoder_config']['use_cache'] = True
    config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)

    # Extract the model state dict
    model_state_dict = {
        k.replace('model.model.', 'model.'): v
        for k, v in checkpoint["model"].items()
        if k.startswith('model.model.')
    }
    if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
        # If not tying weights, we need to add the lm_head weight separately
        model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']

    # Override with provided config if available
    if config_path:
        logger.info(f"Loading config from {config_path}")
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        config = VibeVoiceConfig.from_dict(config_dict)

    # Set the default dtype to bfloat16 before creating the model
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)

    # Create the HuggingFace model
    logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
    model = VibeVoiceForConditionalGeneration(config)

    # Restore original dtype
    torch.set_default_dtype(original_dtype)

    # Load the state dict
    logger.info("Loading weights into model")
    missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)

    if missing_keys:
        logger.warning(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        logger.warning(f"Unexpected keys: {unexpected_keys}")

    # Create output directory
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    # Save the model and config
    logger.info(f"Saving model to {pytorch_dump_folder_path}")

    # Save config
    config.save_pretrained(pytorch_dump_folder_path)

    # Save VibeVoiceProcessor configuration
    logger.info("Saving VibeVoiceProcessor configuration")
    processor_config = {
        "processor_class": "VibeVoiceProcessor",
        "speech_tok_compress_ratio": 3200,
        "db_normalize": True,
        # Audio processor configuration
        "audio_processor": {
            "feature_extractor_type": "VibeVoiceTokenizerProcessor",
            "sampling_rate": 24000,
            "normalize_audio": True,
            "target_dB_FS": -25,
            "eps": 1e-6,
        },
        "language_model_pretrained_name": pretrained_name,
    }

    processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
    with open(processor_config_path, 'w') as f:
        json.dump(processor_config, f, indent=2)
    logger.info(f"Saved processor config to {processor_config_path}")

    # Save model with sharding
    # save_pretrained handles tied weights automatically
    logger.info("Saving model weights with sharding...")
    model.save_pretrained(
        pytorch_dump_folder_path,
        max_shard_size="2GB",  # Set maximum size for each shard
        safe_serialization=True,  # Ensure saving in .safetensors format
    )
    logger.info(f"Model weights saved to {pytorch_dump_folder_path}")

    logger.info("Conversion complete!")

    # Verify the saved model can be loaded
    logger.info("Verifying saved model...")
    loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
    logger.info("Model successfully loaded from saved checkpoint!")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nnscaler_checkpoint_path",
        type=str,
        required=True,
        help="Path to the nnscaler checkpoint (.pt file). For tensor parallel checkpoints, "
        "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
        "and the script will automatically detect and merge all parts.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model directory",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default=None,
        help="Optional path to a config JSON file to override the extracted config",
    )

    args = parser.parse_args()

    convert_vibevoice_nnscaler_checkpoint_to_hf(
        args.nnscaler_checkpoint_path,
        args.pytorch_dump_folder_path,
        args.config_path,
    )


if __name__ == "__main__":
    main()
\ No newline at end of file
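The script above is normally invoked from the command line with `--nnscaler_checkpoint_path` and `--pytorch_dump_folder_path` (plus an optional `--config_path` override). A programmatic sketch of the same conversion is shown below; the paths are placeholders, and the checkpoint must contain the nnscaler `train_args`/`model` entries the function reads.

from vibevoice.scripts.convert_nnscaler_checkpoint_to_transformers import (
    convert_vibevoice_nnscaler_checkpoint_to_hf,
)

# Placeholder paths for illustration only
convert_vibevoice_nnscaler_checkpoint_to_hf(
    checkpoint_path="/path/to/checkpoint_last.pt",
    pytorch_dump_folder_path="/path/to/vibevoice_hf_model",
    config_path=None,  # optional JSON config override
)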