ModelZoo / InspireMusic_pytorch

Commit 0112b0f0, authored Feb 14, 2025 by chenzk

v1.0
Showing 20 changed files with 3004 additions and 0 deletions.
examples/music_generation/inspiremusic/utils/__pycache__/tokenizer_utils.cpython-310.pyc (+0, -0)
examples/music_generation/inspiremusic/utils/audio_utils.py (+624, -0)
examples/music_generation/inspiremusic/utils/binary.py (+155, -0)
examples/music_generation/inspiremusic/utils/class_utils.py (+71, -0)
examples/music_generation/inspiremusic/utils/common.py (+179, -0)
examples/music_generation/inspiremusic/utils/data_utils.py (+105, -0)
examples/music_generation/inspiremusic/utils/executor.py (+121, -0)
examples/music_generation/inspiremusic/utils/file_utils.py (+79, -0)
examples/music_generation/inspiremusic/utils/frontend_utils.py (+126, -0)
examples/music_generation/inspiremusic/utils/hinter.py (+13, -0)
examples/music_generation/inspiremusic/utils/losses.py (+20, -0)
examples/music_generation/inspiremusic/utils/mask.py (+227, -0)
examples/music_generation/inspiremusic/utils/scheduler.py (+738, -0)
examples/music_generation/inspiremusic/utils/tokenizer_utils.py (+221, -0)
examples/music_generation/inspiremusic/utils/train_utils.py (+300, -0)
examples/music_generation/inspiremusic/utils/utils.py (+23, -0)
examples/music_generation/inspiremusic/version.txt (+2, -0)
examples/music_generation/inspiremusic/wavtokenizer/__init__.py (+0, -0)
examples/music_generation/inspiremusic/wavtokenizer/__pycache__/__init__.cpython-310.pyc (+0, -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__init__.py (+0, -0)
examples/music_generation/inspiremusic/utils/__pycache__/tokenizer_utils.cpython-310.pyc (new file, mode 100644)

File added (binary, not shown)

examples/music_generation/inspiremusic/utils/audio_utils.py (new file, mode 100644)

# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import io
import logging
import re
import sys
import inspect
import random
import typing as tp
from functools import partial

import omegaconf
import torch
import torchaudio
import numpy as np
from typing_extensions import Literal
from typing import (
    Any,
    Union,
    Iterable,
    List,
    Dict,
    Optional,
    Tuple,
)
from librosa.filters import mel as librosa_mel_fn
# `ParameterError` is raised by normalize() below, which is adapted from librosa;
# this import fixes a previously undefined name.
from librosa.util.exceptions import ParameterError
from scipy.io.wavfile import read

# Module-level logger (fixes the previously undefined `logger` used in compress()).
logger = logging.getLogger(__name__)

_BoolLike_co = Union[bool, np.bool_]
_IntLike_co = Union[_BoolLike_co, int, "np.integer[Any]"]
_FloatLike_co = Union[_IntLike_co, float, "np.floating[Any]"]


def process_audio(file_path, target_sample_rate=24000):
    audio, sample_rate = torchaudio.load(file_path)
    # Check if the audio needs to be resampled
    if sample_rate != target_sample_rate:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(audio)
    # Convert stereo to mono (if necessary)
    audio = audio.mean(dim=0, keepdim=True) if audio.size(0) == 2 else audio
    return audio, target_sample_rate


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    # global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
    mel_basis = {}
    hann_window = {}
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


def fade_out(audio: torch.Tensor, sample_rate: int, fade_duration: float) -> torch.Tensor:
    """
    Apply a linear fade-out effect to the given audio waveform.

    Parameters:
        audio (torch.Tensor): The audio waveform tensor.
        sample_rate (int): Sample rate of the audio.
        fade_duration (float): Duration of the fade-out effect in seconds.

    Returns:
        torch.Tensor: The audio with the fade-out effect applied.
    """
    fade_samples = int(fade_duration * sample_rate)
    if fade_samples > audio.shape[1]:
        fade_samples = audio.shape[1]  # use the whole length of audio if necessary
    fade_out_envelope = torch.linspace(1.0, 0.0, fade_samples, dtype=audio.dtype, device=audio.device)
    fade_section = audio[:, -fade_samples:].clone()
    fade_section *= fade_out_envelope
    faded_audio = audio.clone()
    faded_audio[:, -fade_samples:] = fade_section
    return faded_audio


def split_wav_into_chunks(num_samples, wav, max_chunk_size, minimum_chunk_size=720):
    num_chunks = (num_samples + max_chunk_size - 1) // max_chunk_size  # Ceiling division
    wav_chunks = []
    for i in range(num_chunks):
        start_idx = i * max_chunk_size
        end_idx = min(start_idx + max_chunk_size, num_samples)
        if (end_idx - start_idx) >= minimum_chunk_size:
            if len(wav.shape) == 2:
                chunk = wav[:, start_idx:end_idx]
            else:
                chunk = wav[start_idx:end_idx]
            wav_chunks.append(chunk)
        else:
            print(f"{num_samples}: {num_chunks}, chunk size={(end_idx - start_idx)} is lower than minimum_chunk_size!")
    return wav_chunks


def tiny(x: Union[float, np.ndarray]) -> _FloatLike_co:
    """Compute the tiny-value corresponding to an input's data type."""
    # Make sure we have an array view
    x = np.asarray(x)
    # Only floating types generate a tiny
    if np.issubdtype(x.dtype, np.floating) or np.issubdtype(x.dtype, np.complexfloating):
        dtype = x.dtype
    else:
        dtype = np.dtype(np.float32)
    return np.finfo(dtype).tiny


def detect_silence(audio, sample_rate, threshold=0.05, min_silence_duration=1):
    """
    Detects the first occurrence of silence in the audio.

    Parameters:
        audio (Tensor): The audio waveform.
        sample_rate (int): The sample rate of the audio.
        threshold (float): The threshold below which the signal is considered silent.
        min_silence_duration (float): The minimum duration of silence in seconds.

    Returns:
        int: The timestamp (in samples) where the silence starts.
    """
    # Convert the audio to a numpy array for easier manipulation
    audio_np = audio.numpy().flatten()
    # Calculate the energy of the signal
    energy = np.abs(audio_np)
    # Find the indices where the energy is below the threshold
    silent_indices = np.where(energy < threshold)[0]
    # Find the start and end of contiguous silent regions
    silent_regions = np.split(silent_indices, np.where(np.diff(silent_indices) != 1)[0] + 1)
    # Filter out regions that are too short
    min_silence_samples = int(min_silence_duration * sample_rate)
    for region in silent_regions:
        if len(region) >= min_silence_samples:
            return region[0]
    # If no silence is found, return the length of the audio
    return len(audio_np)


def trim_audio(waveform, sample_rate=24000, threshold=0.05, min_silence_duration=1,
               minimum_silence_start_sample=24000):
    """
    Trims the audio from the beginning to the first occurrence of silence.

    Parameters:
        waveform (Tensor): The waveform data of the input audio file.
        sample_rate (int): Sample rate of the input audio file.
        threshold (float): The threshold below which the signal is considered silent.
        min_silence_duration (float): The minimum duration of silence in seconds.
    """
    # Detect the first occurrence of silence
    silence_start_sample = detect_silence(waveform, sample_rate, threshold, min_silence_duration)
    if silence_start_sample > minimum_silence_start_sample:
        trimmed_waveform = waveform[:silence_start_sample]
    else:
        trimmed_waveform = waveform[:minimum_silence_start_sample]
    if isinstance(trimmed_waveform, torch.Tensor):
        return trimmed_waveform
    else:
        return trimmed_waveform.unsqueeze(0)  # fixed: unsqueeze() requires a dim argument


def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Normalize an input signal to a user loudness in dB LKFS.
    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): anything below that RMS level will not be rescaled.
    Returns:
        torch.Tensor: Loudness normalized output data.
    """
    energy = wav.pow(2).mean().sqrt().item()
    if energy < energy_floor:
        return wav
    transform = torchaudio.transforms.Loudness(sample_rate)
    input_loudness_db = transform(wav).item()
    # calculate the gain needed to scale to the desired loudness level
    delta_loudness = -loudness_headroom_db - input_loudness_db
    gain = 10.0 ** (delta_loudness / 20.0)
    output = gain * wav
    if loudness_compressor:
        output = torch.tanh(output)
    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
    return output


def normalize(
    S: np.ndarray,
    *,
    norm: Optional[float] = np.inf,
    axis: Optional[int] = 0,
    threshold: Optional[_FloatLike_co] = None,
    fill: Optional[bool] = None,
) -> np.ndarray:
    """Normalize an array along a chosen axis."""
    # Avoid div-by-zero
    if threshold is None:
        threshold = tiny(S)
    elif threshold <= 0:
        raise ParameterError(f"threshold={threshold} must be strictly positive")

    if fill not in [None, False, True]:
        raise ParameterError(f"fill={fill} must be None or boolean")

    if not np.isfinite(S).all():
        raise ParameterError("Input must be finite")

    # All norms only depend on magnitude, let's do that first
    # NOTE: despite the ndarray annotation, S is treated as a torch.Tensor here.
    S = S.numpy()
    mag = np.abs(S).astype(float)

    # For max/min norms, filling with 1 works
    fill_norm = 1

    if norm is None:
        return S
    elif norm == np.inf:
        length = np.max(mag, axis=axis, keepdims=True)
    elif norm == -np.inf:
        length = np.min(mag, axis=axis, keepdims=True)
    elif norm == 0:
        if fill is True:
            raise ParameterError("Cannot normalize with norm=0 and fill=True")
        length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype)
    elif np.issubdtype(type(norm), np.number) and norm > 0:
        length = np.sum(mag ** norm, axis=axis, keepdims=True) ** (1.0 / norm)
        if axis is None:
            fill_norm = mag.size ** (-1.0 / norm)
        else:
            fill_norm = mag.shape[axis] ** (-1.0 / norm)
    else:
        raise ParameterError(f"Unsupported norm: {repr(norm)}")

    # indices where norm is below the threshold
    small_idx = length < threshold

    Snorm = np.empty_like(S)
    if fill is None:
        # Leave small indices un-normalized
        length[small_idx] = 1.0
        Snorm[:] = S / length
    elif fill:
        # If we have a non-zero fill value, we locate those entries by
        # doing a nan-divide.
        # If S was finite, then length is finite (except for small positions)
        length[small_idx] = np.nan
        Snorm[:] = S / length
        Snorm[np.isnan(Snorm)] = fill_norm
    else:
        # Set small values to zero by doing an inf-divide.
        # This is safe (by IEEE-754) as long as S is finite.
        length[small_idx] = np.inf
        Snorm[:] = S / length

    return Snorm


def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (str, optional): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        rescaling = (scale_peak / wav.abs().max())
        if normalize or rescaling < 1:
            wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        rescaling = scale_rms / mono.pow(2).mean().sqrt()
        if normalize or rescaling < 1:
            wav = wav * rescaling
        # NOTE: `_clip_wav` is referenced here but not defined in this file.
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    else:
        assert wav.abs().max() < 1
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
    return wav


def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """
    Convert audio to float 32 bits PCM format.

    Args:
        wav (torch.tensor): Input wav tensor
    Returns:
        same wav in float32 PCM format
    """
    if wav.dtype.is_floating_point:
        return wav
    elif wav.dtype == torch.int16:
        return wav.float() / 2 ** 15
    elif wav.dtype == torch.int32:
        return wav.float() / 2 ** 31
    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")


def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formulas for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm(wav)) != wav`.

    Args:
        wav (torch.tensor): Input wav tensor
    Returns:
        same wav in int16 PCM format
    """
    if wav.dtype.is_floating_point:
        assert wav.abs().max() <= 1
        candidate = (wav * 2 ** 15).round()
        if candidate.max() >= 2 ** 15:  # clipping would occur
            candidate = (wav * (2 ** 15 - 1)).round()
        return candidate.short()
    else:
        assert wav.dtype == torch.int16
        return wav


def compress(wav: torch.Tensor, sr: int,
             target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3",
             bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]:
    """Convert audio waveform to a specified lossy format: mp3, ogg, flac

    Args:
        wav (torch.Tensor): Input wav tensor.
        sr (int): Sampling rate.
        target_format (str): Compression format (e.g., 'mp3').
        bitrate (str): Bitrate for compression.

    Returns:
        Tuple of compressed WAV tensor and sampling rate.
    """
    # Extract the bit rate from string (e.g., '128k')
    match = re.search(r"\d+(\.\d+)?", str(bitrate))
    parsed_bitrate = float(match.group()) if match else None
    assert parsed_bitrate, f"Invalid bitrate specified (got {parsed_bitrate})"
    try:
        # Create a virtual file instead of saving to disk
        buffer = io.BytesIO()
        torchaudio.save(buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate)
        # Move to the beginning of the file
        buffer.seek(0)
        compressed_wav, sr = torchaudio.load(buffer)
        return compressed_wav, sr
    except RuntimeError:
        logger.warning(f"compression failed skipping compression: {target_format} {parsed_bitrate}")
        return wav, sr


def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor:
    """Convert a batch of audio files to MP3 format, maintaining the original shape.

    This function takes a batch of audio files represented as a PyTorch tensor, converts
    them to MP3 format using the specified bitrate, and returns the batch in the same
    shape as the input.

    Args:
        wav_tensor (torch.Tensor): Batch of audio files represented as a tensor.
            Shape should be (batch_size, channels, length).
        sr (int): Sampling rate of the audio.
        bitrate (str): Bitrate for MP3 conversion, default is '128k'.

    Returns:
        torch.Tensor: Batch of audio files converted to MP3 format, with the same
            shape as the input tensor.
    """
    device = wav_tensor.device
    batch_size, channels, original_length = wav_tensor.shape

    # Flatten tensor for conversion and move to CPU
    wav_tensor_flat = wav_tensor.view(1, -1).cpu()

    # Convert to MP3 format with specified bitrate
    wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate)

    # Reshape back to original batch format and trim or pad if necessary
    wav_tensor = wav_tensor_flat.view(batch_size, channels, -1)
    compressed_length = wav_tensor.shape[-1]
    if compressed_length > original_length:
        wav_tensor = wav_tensor[:, :, :original_length]  # Trim excess frames
    elif compressed_length < original_length:
        padding = torch.zeros(batch_size, channels, original_length - compressed_length, device=device)
        wav_tensor = torch.cat((wav_tensor, padding), dim=-1)  # Pad with zeros

    # Move tensor back to the original device
    return wav_tensor.to(device)


def get_aac(
    wav_tensor: torch.Tensor,
    sr: int,
    bitrate: str = "128k",
    lowpass_freq: tp.Optional[int] = None,
) -> torch.Tensor:
    """Converts a batch of audio tensors to AAC format and then back to tensors.

    This function first saves the input tensor batch as WAV files, then uses FFmpeg to convert
    these WAV files to AAC format. Finally, it loads the AAC files back into tensors.

    Args:
        wav_tensor (torch.Tensor): A batch of audio files represented as a tensor.
            Shape should be (batch_size, channels, length).
        sr (int): Sampling rate of the audio.
        bitrate (str): Bitrate for AAC conversion, default is '128k'.
        lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied.

    Returns:
        torch.Tensor: Batch of audio files converted to AAC and back, with the same
            shape as the input tensor.
    """
    import tempfile
    import subprocess

    device = wav_tensor.device
    batch_size, channels, original_length = wav_tensor.shape

    # Parse the bitrate value from the string
    match = re.search(r"\d+(\.\d+)?", bitrate)
    parsed_bitrate = (match.group() if match else "128")  # Default to 128 if parsing fails

    # Flatten tensor for conversion and move to CPU
    wav_tensor_flat = wav_tensor.view(1, -1).cpu()

    with tempfile.NamedTemporaryFile(suffix=".wav") as f_in, \
            tempfile.NamedTemporaryFile(suffix=".aac") as f_out:
        input_path, output_path = f_in.name, f_out.name
        # Save the tensor as a WAV file
        torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg")
        # Prepare FFmpeg command for AAC conversion
        command = [
            "ffmpeg", "-y", "-i", input_path,
            "-ar", str(sr), "-b:a", f"{parsed_bitrate}k", "-c:a", "aac",
        ]
        if lowpass_freq is not None:
            command += ["-cutoff", str(lowpass_freq)]
        command.append(output_path)
        try:
            # Run FFmpeg and suppress output
            subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            # Load the AAC audio back into a tensor
            aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg")
        except Exception as exc:
            raise RuntimeError(
                f"Failed to run command {' '.join(command)} "
                "(Often this means ffmpeg is not installed or the encoder is not supported, "
                "make sure you installed an older version ffmpeg<5)"
            ) from exc

    original_length_flat = batch_size * channels * original_length
    compressed_length_flat = aac_tensor.shape[-1]

    # Trim excess frames
    if compressed_length_flat > original_length_flat:
        aac_tensor = aac_tensor[:, :original_length_flat]
    # Pad the shortened frames
    elif compressed_length_flat < original_length_flat:
        padding = torch.zeros(1, original_length_flat - compressed_length_flat, device=device)
        aac_tensor = torch.cat((aac_tensor, padding), dim=-1)

    # Reshape and adjust length to match original tensor
    wav_tensor = aac_tensor.view(batch_size, channels, -1)
    compressed_length = wav_tensor.shape[-1]
    assert compressed_length == original_length, (
        "AAC-compressed audio does not have the same frames as original one. "
        "One reason can be ffmpeg is not installed and used as the proper backend "
        "for torchaudio, or the AAC encoder is not correct. Run "
        "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure we see an entry for "
        "AAC in the output."
    )
    return wav_tensor.to(device)

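For orientation, a minimal usage sketch of these utilities. The file path and mel parameters below are placeholders for illustration, not the project's configured values:

# Hypothetical end-to-end sketch: load, peak-normalize, fade out, compute mel.
audio, sr = process_audio("example.wav", target_sample_rate=24000)  # "example.wav" is a placeholder
audio = normalize_audio(audio, strategy="peak")
audio = fade_out(audio, sr, fade_duration=1.0)
mel = mel_spectrogram(audio, n_fft=1024, num_mels=128, sampling_rate=sr,
                      hop_size=256, win_size=1024, fmin=0, fmax=12000)
print(mel.shape)  # torch.Size([1, 128, frames])
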
examples/music_generation/inspiremusic/utils/binary.py (new file, mode 100644)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io
import json
import struct
import typing as tp

# Format is the `ECDC` magic code, followed by the header size as uint32.
# Then a uint8 indicates the protocol version (0).
# The header is then provided as json and should contain all the required
# information for decoding. A raw stream of bytes is then provided
# and should be interpretable using the json header.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'


def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    meta_dumped = json.dumps(metadata).encode('utf-8')
    version = 0
    header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped))
    fo.write(header)
    fo.write(meta_dumped)
    fo.flush()


def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
    buf = b""
    while len(buf) < size:
        new_buf = fo.read(size)
        if not new_buf:
            raise EOFError("Impossible to read enough data from the stream, "
                           f"{size} bytes remaining.")
        buf += new_buf
        size -= len(new_buf)
    return buf


def read_ecdc_header(fo: tp.IO[bytes]):
    header_bytes = _read_exactly(fo, _encodec_header_struct.size)
    magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
    if magic != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if version != 0:
        raise ValueError("Version not supported.")
    meta_bytes = _read_exactly(fo, meta_size)
    return json.loads(meta_bytes.decode('utf-8'))


class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """

    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self._current_value = 0
        self._current_bits = 0
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object."""
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        while self._current_bits >= 8:
            lower_8bits = self._current_value & 0xff
            self._current_bits -= 8
            self._current_value >>= 8
            self.fo.write(bytes([lower_8bits]))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode."""
        if self._current_bits:
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()


class BitUnpacker:
    """BitUnpacker does the opposite of `BitPacker`.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to pull the bytes from.
    """

    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        self._mask = (1 << bits) - 1
        self._current_value = 0
        self._current_bits = 0

    def pull(self) -> tp.Optional[int]:
        """
        Pull a single value from the stream, potentially reading some
        extra bytes from the underlying file-object.
        Returns `None` when reaching the end of the stream.
        """
        while self._current_bits < self.bits:
            buf = self.fo.read(1)
            if not buf:
                return None
            character = buf[0]
            self._current_value += character << self._current_bits
            self._current_bits += 8
        out = self._current_value & self._mask
        self._current_value >>= self.bits
        self._current_bits -= self.bits
        return out


def test():
    import torch
    torch.manual_seed(1234)
    for rep in range(4):
        length: int = torch.randint(10, 2_000, (1,)).item()
        bits: int = torch.randint(1, 16, (1,)).item()
        tokens: tp.List[int] = torch.randint(2 ** bits, (length,)).tolist()
        rebuilt: tp.List[int] = []
        buf = io.BytesIO()
        packer = BitPacker(bits, buf)
        for token in tokens:
            packer.push(token)
        packer.flush()
        buf.seek(0)
        unpacker = BitUnpacker(bits, buf)
        while True:
            value = unpacker.pull()
            if value is None:
                break
            rebuilt.append(value)
        assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits)
        for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
            assert a == b, (idx, a, b)


if __name__ == '__main__':
    test()

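A minimal round-trip sketch of the ECDC helpers above; the metadata dict is invented for illustration:

buf = io.BytesIO()
write_ecdc_header(buf, {"model": "demo", "bits": 10})  # illustrative metadata
packer = BitPacker(10, buf)
for token in [3, 512, 1023]:
    packer.push(token)
packer.flush()

buf.seek(0)
meta = read_ecdc_header(buf)                # {'model': 'demo', 'bits': 10}
unpacker = BitUnpacker(meta["bits"], buf)
tokens = []
while (value := unpacker.pull()) is not None:
    tokens.append(value)
# tokens == [3, 512, 1023]; for other bit widths a trailing "ghost" value may appear
# because flush() pads the final partial byte.
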
examples/music_generation/inspiremusic/utils/class_utils.py (new file, mode 100644)

# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
#            2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from inspiremusic.transformer.activation import Swish
from inspiremusic.transformer.subsampling import (
    LinearNoSubsampling,
    EmbedinigNoSubsampling,  # spelling follows the upstream class name
    Conv1dSubsampling2,
    Conv2dSubsampling4,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
)
from inspiremusic.transformer.embedding import (
    PositionalEncoding,
    RelPositionalEncoding,
    WhisperPositionalEncoding,
    LearnablePositionalEncoding,
    NoPositionalEncoding,
)
from inspiremusic.transformer.attention import (
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from inspiremusic.transformer.embedding import EspnetRelPositionalEncoding
from inspiremusic.transformer.subsampling import LegacyLinearNoSubsampling

INSPIREMUSIC_ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}

INSPIREMUSIC_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
    "linear_legacy": LegacyLinearNoSubsampling,
    "embed": EmbedinigNoSubsampling,
    "conv1d2": Conv1dSubsampling2,
    "conv2d": Conv2dSubsampling4,
    "conv2d6": Conv2dSubsampling6,
    "conv2d8": Conv2dSubsampling8,
    'paraformer_dummy': torch.nn.Identity,
}

INSPIREMUSIC_EMB_CLASSES = {
    "embed": PositionalEncoding,
    "abs_pos": PositionalEncoding,
    "rel_pos": RelPositionalEncoding,
    "rel_pos_espnet": EspnetRelPositionalEncoding,
    "no_pos": NoPositionalEncoding,
    "abs_pos_whisper": WhisperPositionalEncoding,
    "embed_learnable_pe": LearnablePositionalEncoding,
}

INSPIREMUSIC_ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}

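These registries are meant to be indexed with the strings that appear in model configs; for example (illustrative keys, assuming a torch version that provides `nn.SiLU`):

act = INSPIREMUSIC_ACTIVATION_CLASSES["gelu"]()    # instantiates torch.nn.GELU
pos_emb_cls = INSPIREMUSIC_EMB_CLASSES["rel_pos"]  # RelPositionalEncoding class
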
examples/music_generation/inspiremusic/utils/common.py (new file, mode 100644)

# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#            2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Utility functions for Transformer."""
import random  # added: required by set_all_random_seed below
from typing import List

import numpy as np  # added: required by set_all_random_seed below
import torch

IGNORE_ID = -1
MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"]
DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
}


def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    max_len = max([len(item) for item in xs])
    batchs = len(xs)
    ndim = xs[0].ndim
    if ndim == 1:
        pad_res = torch.zeros(batchs, max_len,
                              dtype=xs[0].dtype, device=xs[0].device)
    elif ndim == 2:
        pad_res = torch.zeros(batchs, max_len, xs[0].shape[1],
                              dtype=xs[0].dtype, device=xs[0].device)
    elif ndim == 3:
        pad_res = torch.zeros(batchs, max_len, xs[0].shape[1], xs[0].shape[2],
                              dtype=xs[0].dtype, device=xs[0].device)
    else:
        raise ValueError(f"Unsupported ndim: {ndim}")
    pad_res.fill_(pad_value)
    for i in range(batchs):
        pad_res[i, :len(xs[i])] = xs[i]
    return pad_res


def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> torch.Tensor:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
                                pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return (numerator / denominator).detach()


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def topk_sampling(weighted_scores, decoded_tokens, top_k=25):
    zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf')
    values, indices = torch.topk(weighted_scores, top_k)
    zeros.scatter_(-1, indices, values)
    return random_sampling(zeros, decoded_tokens)


# Repetition Aware Sampling in VALL-E 2
def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
    if rep_num >= win_size * tau_r:
        top_ids = random_sampling(weighted_scores, decoded_tokens)
    return top_ids


def caras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
    weighted_scores, cfg_weighted_scores = weighted_scores
    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
    if rep_num >= win_size * tau_r:
        top_ids = random_sampling(cfg_weighted_scores, decoded_tokens)
    return top_ids


def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
    prob, indices = [], []
    cum_prob = 0.0
    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
    for i in range(len(sorted_idx)):
        # sampling both top-p and numbers.
        if cum_prob < top_p and len(prob) < top_k:
            cum_prob += sorted_value[i]
            prob.append(sorted_value[i])
            indices.append(sorted_idx[i])
        else:
            break
    prob = torch.tensor(prob).to(weighted_scores)
    indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
    top_ids = indices[prob.multinomial(1, replacement=True)]
    return top_ids


def random_sampling(weighted_scores, decoded_tokens):
    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
    return top_ids


def fade_in_out(fade_in_mel, fade_out_mel, window):
    device = fade_in_mel.device
    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
    mel_overlap_len = int(window.shape[0] / 2)
    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
        fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
    return fade_in_mel.to(device)


def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    mask = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues
    #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    mask = (1.0 - mask) * torch.finfo(dtype).min
    return mask

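A toy illustration of the repetition-aware sampling path; the scores and token history below are fabricated, whereas in practice they come from the language model head during decoding:

torch.manual_seed(0)
weighted_scores = torch.randn(100)   # fake logits over a 100-token vocabulary
decoded_tokens = [7] * 10            # fake history: token 7 repeated
top_id = ras_sampling(weighted_scores, decoded_tokens,
                      top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
# If the nucleus pick matches the repeated token often enough within the window,
# ras_sampling falls back to plain random_sampling over the full distribution.
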
examples/music_generation/inspiremusic/utils/data_utils.py (new file, mode 100644)

# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import DataLoader
from inspiremusic.dataset.dataset import Dataset
import numpy as np
import librosa


def audio_process_dataset_and_dataloader(args, configs):
    input_dataset = Dataset(args.input_data, data_pipeline=configs['data_pipeline'],
                            mode='processing', shuffle=True, partition=True)
    # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
    input_data_loader = DataLoader(input_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   prefetch_factor=args.prefetch)
    return input_dataset, input_data_loader


def is_silent(wav_path, threshold=0.01, frame_length=2048, hop_length=512):
    y, sr = librosa.load(wav_path, sr=None)
    rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
    silent_frames = np.sum(rms < threshold) / len(rms)
    silence_fraction_threshold = 0.95
    return silent_frames >= silence_fraction_threshold


def rich_captions(text=None, tags=None, lyrics=None, chorus="verse", start_time=0.0, end_time=30.0):
    if text is None and tags is None and lyrics is None:
        return None
    else:
        if start_time is None:
            start_time = 0.0
        if end_time is None:
            end_time = 30.0
        if chorus is None:
            chorus = "verse"
        captions = f"<|{start_time:.1f}|><|{chorus}|>"
        if tags is not None:
            captions += f"<|{tags}|>"
        if text is not None:
            captions += f"<|{text}|>"
        if lyrics is not None:
            captions += f"<|lyrics|><|{lyrics}|>"
        captions += f"<|{end_time:.1f}|>"
        return captions


def process_tags(infile, outfile, timefile=None):
    key_list = []
    with open(infile, "r") as f:
        for line in f:
            sec = line.strip()
            key_list.append(sec)

    if timefile is None:
        with open(outfile, 'w') as f:
            for k in key_list:
                parts = k.rsplit('_', 1)
                text = parts[0].replace('_', ' ') + ', ' + parts[1]
                caption = rich_captions(text, None, None)
                if caption is not None:
                    f.write("%s\t%s\n" % (k, caption))
    else:
        times = {}
        with open(timefile, "r") as f:
            for line in f:
                sec = line.strip().split("\t")
                if len(sec) == 2:
                    times[sec[0]] = sec[1]

        with open(outfile, 'w') as f:
            for k in key_list:
                parts = k.rsplit('_', 1)
                text = parts[0].replace('_', ' ') + ', ' + parts[1]
                if k in times.keys():
                    caption = rich_captions(text, None, None, "verse", 0.0, float(times[k]))
                    if caption is not None:
                        f.write("%s\t%s\n" % (k, caption))


def process_trans(infile, outfile):
    trans = {}
    with open(infile, "r") as f:
        for line in f:
            sec = line.strip().split("\t")
            if len(sec) == 2:
                trans[sec[0]] = sec[1]
            else:
                print(line)

    with open(outfile, 'w') as f:
        for k, v in trans.items():
            f.write("%s\t%s\n" % (k, rich_captions(v)))

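For reference, `rich_captions` produces the tagged caption format consumed downstream; a quick example with invented values:

caption = rich_captions(text="upbeat electronic dance track", tags="edm, energetic",
                        chorus="chorus", start_time=0.0, end_time=30.0)
print(caption)
# <|0.0|><|chorus|><|edm, energetic|><|upbeat electronic dance track|><|30.0|>
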
examples/music_generation/inspiremusic/utils/executor.py (new file, mode 100644)

# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#            2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
import os

import torch
import torch.distributed as dist

from inspiremusic.utils.train_utils import (
    update_parameter_and_lr, log_per_step, log_per_save,
    batch_forward, batch_backward, save_model, inspiremusic_join)
from torch.cuda.amp import GradScaler, autocast


class Executor:

    def __init__(self):
        self.step = 0
        self.epoch = 0
        self.rank = int(os.environ.get('RANK', 0))
        self.device = torch.device('cuda:{}'.format(self.rank))

    def train_one_epoch(self, model, optimizer, scheduler, train_data_loader,
                        cv_data_loader, writer, info_dict, group_join, scaler=None):
        ''' Train one epoch
        '''
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        model.train()
        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["epoch"] = self.epoch
                info_dict["batch_idx"] = batch_idx
                if inspiremusic_join(group_join, info_dict):
                    break

                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    context = nullcontext

                with context():
                    with autocast(enabled=scaler is not None):
                        info_dict = batch_forward(model, batch_dict, info_dict, scaler)
                    info_dict = batch_backward(model, info_dict, scaler)

                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler)
                log_per_step(writer, info_dict)
                # NOTE specify save_per_step in inspiremusic.yaml if you want to enable step save
                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    dist.barrier()
                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, scaler=scaler)
                    model.train()
                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    self.step += 1
        dist.barrier()
        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, scaler=scaler)

    @torch.inference_mode()
    def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, capped_at=5, scaler=None):
        ''' Cross validation
        '''
        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(
            self.epoch, self.step + 1, on_batch_end, self.rank))
        model.eval()
        total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
        stop = capped_at
        for batch_idx, batch_dict in enumerate(cv_data_loader):
            info_dict["tag"] = "CV"
            info_dict["step"] = self.step
            info_dict["epoch"] = self.epoch
            info_dict["batch_idx"] = batch_idx

            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts
            if capped_at > 0:
                if stop <= 0:
                    continue
                else:
                    stop -= 1

            with autocast(enabled=scaler is not None):
                info_dict = batch_forward(model, batch_dict, info_dict, scaler)

            for k, v in info_dict['loss_dict'].items():
                if k not in total_loss_dict:
                    total_loss_dict[k] = []
                total_loss_dict[k].append(v.item() * num_utts)
            log_per_step(None, info_dict)
        for k, v in total_loss_dict.items():
            total_loss_dict[k] = sum(v) / total_num_utts
        info_dict['loss_dict'] = total_loss_dict
        log_per_save(writer, info_dict)
        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else \
            'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
        save_model(model, model_name, info_dict)

examples/music_generation/inspiremusic/utils/file_utils.py (new file, mode 100644)

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#            2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import torchaudio
import logging

logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')


def read_trans(list_file):
    trans = {}
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            sec = line.strip().split("\t")
            if len(sec) > 1:
                if sec[0] not in trans.keys():
                    trans[sec[0]] = sec[1]
    return trans


def read_scp(list_file):
    scp = {}
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            sec = line.strip().split(" ")
            if len(sec) > 1:
                if sec[0] not in scp.keys():
                    scp[sec[0]] = sec[1]
    return scp


def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
    return lists


def read_json_lists(list_file):
    lists = read_lists(list_file)
    results = {}
    for fn in lists:
        with open(fn, 'r', encoding='utf8') as fin:
            results.update(json.load(fin))
    return results


def load_wav(wav, target_sr):
    audio, sample_rate = torchaudio.load(wav)
    audio = audio.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
        # fixed: resample `audio` (the original referenced an undefined name `speech`)
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(audio)
    return audio


def speed_change(waveform, sample_rate, speed_factor: str):
    effects = [
        ["tempo", speed_factor],  # speed_factor
        ["rate", f"{sample_rate}"]
    ]
    augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, effects)
    return augmented_waveform, new_sample_rate

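A brief sketch of the loading helpers above; the path and rates are placeholders, and note that `load_wav` only downsamples, asserting the source rate is higher than the target:

audio = load_wav("example.wav", target_sr=24000)  # mono (1, T) at 24 kHz; placeholder path
faster, sr = speed_change(audio, 24000, "1.1")    # 10% faster via the sox "tempo" effect
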
examples/music_generation/inspiremusic/utils/frontend_utils.py (new file, mode 100644)

# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


# whether the text contains a Chinese character
def contains_chinese(text):
    return bool(chinese_char_pattern.search(text))


# replace special symbols
def replace_corner_mark(text):
    text = text.replace('²', '平方')  # "squared"
    text = text.replace('³', '立方')  # "cubed"
    return text


# remove meaningless symbols
def remove_bracket(text):
    text = text.replace('(', '').replace(')', '')
    text = text.replace('【', '').replace('】', '')
    text = text.replace('`', '').replace('`', '')
    text = text.replace("——", " ")
    return text


# spell out Arabic numerals
def spell_out_number(text: str, inflect_parser):
    new_text = []
    st = None
    for i, c in enumerate(text):
        if not c.isdigit():
            if st is not None:
                num_str = inflect_parser.number_to_words(text[st:i])
                new_text.append(num_str)
                st = None
            new_text.append(c)
        else:
            if st is None:
                st = i
    if st is not None and st < len(text):
        num_str = inflect_parser.number_to_words(text[st:])
        new_text.append(num_str)
    return ''.join(new_text)


# split paragraph logic:
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
# 2. compute sentence len according to lang
# 3. split sentence according to punctuation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60,
                    merge_len=20, comma_split=False):
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
    else:
        pounc = ['.', '?', '!', ';', ':']
    if comma_split:
        pounc.extend([',', ','])
    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1
    if len(utts) == 0:
        if lang == "zh":
            utts.append(text + '。')
        else:
            utts.append(text + '.')
    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts


# remove blanks between Chinese characters
def replace_blank(text: str):
    out_str = []
    for i, c in enumerate(text):
        if c == " ":
            if ((text[i + 1].isascii() and text[i + 1] != " ") and
                    (text[i - 1].isascii() and text[i - 1] != " ")):
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)

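A small demonstration of the sentence splitting for English, with a whitespace tokenizer standing in for the real tokenizer and budgets chosen only for illustration:

text = "First sentence is short. Second sentence is a little longer! Third one ends here?"
utts = split_paragraph(text, tokenize=str.split, lang="en",
                       token_max_n=10, token_min_n=4, merge_len=2)
print(utts)  # sentences regrouped so each chunk stays within the token budget
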
examples/music_generation/inspiremusic/utils/hinter.py (new file, mode 100644)

import sys
import torch.distributed
import logging

HINTED = set()


def hint_once(content, uid, rank=None):
    # Log `content` only the first time `uid` is seen, and only on the given rank
    # (or on every process if rank is None or torch.distributed is not initialized).
    if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank:
        if uid not in HINTED:
            logging.info(content, stacklevel=3)
            HINTED.add(uid)

examples/music_generation/inspiremusic/utils/losses.py (new file, mode 100644)

import torch
import torch.nn.functional as F


def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        m_DG = torch.median((dr - dg))
        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
        loss += tau - F.relu(tau - L_rel)
    return loss


def mel_loss(real_speech, generated_speech, mel_transforms):
    loss = 0
    for transform in mel_transforms:
        mel_r = transform(real_speech)
        mel_g = transform(generated_speech)
        loss += F.l1_loss(mel_g, mel_r)
    return loss

examples/music_generation/inspiremusic/utils/mask.py
0 → 100644
View file @
0112b0f0
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

'''
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).
    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.
    In encoder, fully attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.
    When streaming is need, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.
    Args:
        size (int): size of mask
        str device (str): "cpu" or "cuda" or torch.Tensor.device
        dtype (torch.device): result dtype
    Returns:
        torch.Tensor: mask
    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret)
'''


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).
    This mask is used only in the decoder, which works in an auto-regressive
    mode, meaning the current step can only attend to the steps on its left.
    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder.
    See subsequent_chunk_mask for the chunk-based attention mask.
    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
    Returns:
        torch.Tensor: mask
    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder
    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
    Returns:
        torch.Tensor: mask
    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.
    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]
    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.
    See description of make_non_pad_mask.
    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.
    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
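A small worked example of how these helpers fit together for a streaming encoder. The batch shape, feature dimension, and chunk settings below are illustrative assumptions:

# Hypothetical usage sketch: build a padding mask, then combine it with a
# static chunk mask for 3 sequences of lengths 5, 3 and 2.
import torch

lengths = torch.tensor([5, 3, 2])
xs = torch.randn(3, 5, 80)                    # (B, L, D); D=80 is assumed
masks = ~make_pad_mask(lengths).unsqueeze(1)  # (B, 1, L), True = valid frame
chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=2,
    num_decoding_left_chunks=-1)
print(chunk_masks.shape)                      # torch.Size([3, 5, 5])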
examples/music_generation/inspiremusic/utils/scheduler.py
0 → 100644
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Ximalaya Inc (Yuguang Yang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
# NeMo(https://github.com/NVIDIA/NeMo)
from typing import Union

import math
import warnings
import torch
from torch.optim.lr_scheduler import _LRScheduler


class WarmupLR(_LRScheduler):
    """The WarmupLR scheduler

    This scheduler is almost the same as the NoamLR scheduler, except for the
    following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals optimizer.lr in this scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps

        # __init__() must be invoked before setting field
        # because step() is also invoked in __init__()
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        step_num = self.last_epoch + 1
        if self.warmup_steps == 0:
            return [lr * step_num**-0.5 for lr in self.base_lrs]
        else:
            return [
                lr * self.warmup_steps**0.5 *
                min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
                for lr in self.base_lrs
            ]

    def set_step(self, step: int):
        self.last_epoch = step


class WarmupPolicy(_LRScheduler):
    """Adds warmup kwargs and warmup logic to lr policy.
    All arguments should be passed as kwargs for clarity.
    Args:
        warmup_steps: Number of training steps in warmup stage
        warmup_ratio: Ratio of warmup steps to total steps
        max_steps: Total number of steps while training or `None` for
            infinite training
    """

    def __init__(self,
                 optimizer,
                 *,
                 warmup_steps=None,
                 warmup_ratio=None,
                 max_steps=None,
                 min_lr=0.0,
                 last_epoch=-1):
        assert not (warmup_steps is not None and warmup_ratio is not None), \
            "Either use particular number of step or ratio"
        assert warmup_ratio is None or max_steps is not None, \
            "If there is a ratio, there should be a total steps"

        # It is necessary to assign all attributes *before* __init__,
        # as class is wrapped by an inner class.
        self.max_steps = max_steps
        if warmup_steps is not None:
            self.warmup_steps = warmup_steps
        elif warmup_ratio is not None:
            self.warmup_steps = int(warmup_ratio * max_steps)
        else:
            self.warmup_steps = 0

        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed "
                "by the scheduler, please use `get_last_lr()`.",
                UserWarning,
                stacklevel=2)

        step = self.last_epoch

        if step <= self.warmup_steps and self.warmup_steps > 0:
            return self._get_warmup_lr(step)

        if step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)

    def _get_warmup_lr(self, step):
        lr_val = (step + 1) / (self.warmup_steps + 1)
        return [initial_lr * lr_val for initial_lr in self.base_lrs]

    def _get_lr(self, step):
        """Simple const lr policy"""
        return self.base_lrs


class SquareRootConstantPolicy(_LRScheduler):
    """Holds the lr at 1 / sqrt(constant_steps), then drops to min_lr
    after max_steps.
    All arguments should be passed as kwargs for clarity.
    Args:
        constant_steps: Number of training steps to hold the constant lr
        constant_ratio: Ratio of constant steps to total steps
        max_steps: Total number of steps while training or `None` for
            infinite training
    """

    def __init__(self,
                 optimizer,
                 *,
                 constant_steps=None,
                 constant_ratio=None,
                 max_steps=None,
                 min_lr=0.0,
                 last_epoch=-1):
        assert not (constant_steps is not None
                    and constant_ratio is not None), \
            "Either use particular number of step or ratio"
        assert constant_ratio is None or max_steps is not None, \
            "If there is a ratio, there should be a total steps"

        # It is necessary to assign all attributes *before* __init__,
        # as class is wrapped by an inner class.
        self.max_steps = max_steps
        if constant_steps is not None:
            self.constant_steps = constant_steps
        elif constant_ratio is not None:
            self.constant_steps = int(constant_ratio * max_steps)
        else:
            self.constant_steps = 0

        self.constant_lr = 1 / (constant_steps ** 0.5)
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed "
                "by the scheduler, please use `get_last_lr()`.",
                UserWarning,
                stacklevel=2)

        step = self.last_epoch

        if step <= self.constant_steps:
            return [self.constant_lr for _ in self.base_lrs]

        if step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)

    def _get_lr(self, step):
        """Simple const lr policy"""
        return self.base_lrs


class WarmupHoldPolicy(WarmupPolicy):
    """Variant of WarmupPolicy which maintains a high
    learning rate for a defined number of steps.
    All arguments should be passed as kwargs for clarity.
    Args:
        warmup_steps: Number of training steps in warmup stage
        warmup_ratio: Ratio of warmup steps to total steps
        hold_steps: Number of training steps to
            hold the learning rate after warm up
        hold_ratio: Ratio of hold steps to total steps
        max_steps: Total number of steps while training or `None` for
            infinite training
    """

    def __init__(
        self,
        optimizer,
        *,
        warmup_steps=None,
        warmup_ratio=None,
        hold_steps=None,
        hold_ratio=None,
        max_steps=None,
        min_lr=0.0,
        last_epoch=-1,
    ):
        assert not (hold_steps is not None and hold_ratio is not None), \
            "Either use particular number of step or ratio"
        assert hold_ratio is None or max_steps is not None, \
            "If there is a ratio, there should be a total steps"

        self.min_lr = min_lr
        self._last_warmup_lr = 0.0

        # Necessary to duplicate as class attributes are hidden in inner class
        self.max_steps = max_steps
        if warmup_steps is not None:
            self.warmup_steps = warmup_steps
        elif warmup_ratio is not None:
            self.warmup_steps = int(warmup_ratio * max_steps)
        else:
            self.warmup_steps = 0

        if hold_steps is not None:
            self.hold_steps = hold_steps + self.warmup_steps
        elif hold_ratio is not None:
            self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
        else:
            self.hold_steps = 0

        super().__init__(
            optimizer,
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
            max_steps=max_steps,
            last_epoch=last_epoch,
            min_lr=min_lr,
        )

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
                stacklevel=2)

        step = self.last_epoch

        # Warmup phase
        if step <= self.warmup_steps and self.warmup_steps > 0:
            return self._get_warmup_lr(step)

        # Hold phase
        if (step >= self.warmup_steps) and (step < self.hold_steps):
            return self.base_lrs

        if step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)


class WarmupAnnealHoldPolicy(_LRScheduler):
    """Adds warmup kwargs and warmup logic to lr policy.
    All arguments should be passed as kwargs for clarity.
    Args:
        warmup_steps: Number of training steps in warmup stage
        warmup_ratio: Ratio of warmup steps to total steps
        max_steps: Total number of steps while training or `None` for
            infinite training
        min_lr: Minimum lr to hold the learning rate after decay at.
        constant_steps: Number of steps to keep lr constant at.
        constant_ratio: Ratio of steps to keep lr constant.
    """

    def __init__(
        self,
        optimizer,
        *,
        warmup_steps=None,
        warmup_ratio=None,
        constant_steps=None,
        constant_ratio=None,
        max_steps=None,
        min_lr=0.0,
        last_epoch=-1,
    ):
        assert not (warmup_steps is not None and warmup_ratio is not None), \
            "Either use particular number of step or ratio"
        assert not (constant_steps is not None
                    and constant_ratio is not None), \
            "Either use constant_steps or constant_ratio"
        assert warmup_ratio is None or max_steps is not None, \
            "If there is a ratio, there should be a total steps"

        # It is necessary to assign all attributes *before* __init__,
        # as class is wrapped by an inner class.
        self.max_steps = max_steps

        if warmup_steps is not None:
            self.warmup_steps = warmup_steps
        elif warmup_ratio is not None:
            self.warmup_steps = int(warmup_ratio * max_steps)
        else:
            self.warmup_steps = 0

        if constant_steps is not None:
            self.constant_steps = constant_steps
        elif constant_ratio is not None:
            self.constant_steps = int(constant_ratio * max_steps)
        else:
            self.constant_steps = 0

        self.decay_steps = max_steps - (self.constant_steps +
                                        self.warmup_steps)

        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed "
                "by the scheduler, please use `get_last_lr()`.",
                UserWarning,
                stacklevel=2)

        step = self.last_epoch

        # Warmup steps
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return self._get_warmup_lr(step)

        # Constant steps after warmup and decay
        if self.constant_steps > 0 and \
                (self.warmup_steps + self.decay_steps) < step <= self.max_steps:
            return self._get_constant_lr(step)

        # Min lr after max steps of updates
        if step > self.max_steps:
            return [self.min_lr for _ in self.base_lrs]

        return self._get_lr(step)

    def _get_warmup_lr(self, step):
        lr_val = (step + 1) / (self.warmup_steps + 1)
        return [initial_lr * lr_val for initial_lr in self.base_lrs]

    def _get_constant_lr(self, step):
        return [self.min_lr for _ in self.base_lrs]

    def _get_lr(self, step):
        """Simple const lr policy"""
        return self.base_lrs


def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
    mult = ((max_steps - step) / max_steps) ** 0.5
    out_lr = initial_lr * mult
    out_lr = max(out_lr, min_lr)
    return out_lr


def _square_annealing(initial_lr, step, max_steps, min_lr):
    mult = ((max_steps - step) / max_steps) ** 2
    out_lr = initial_lr * mult
    out_lr = max(out_lr, min_lr)
    return out_lr


def _cosine_annealing(initial_lr, step, max_steps, min_lr):
    mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
    out_lr = (initial_lr - min_lr) * mult + min_lr
    return out_lr


def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
                                         decay_steps, min_lr):
    assert max_lr > min_lr
    # Use linear warmup for the initial part.
    if warmup_steps > 0 and step <= warmup_steps:
        return max_lr * float(step) / float(warmup_steps)

    # For any steps larger than `decay_steps`, use `min_lr`.
    if step > warmup_steps + decay_steps:
        return min_lr

    # If we are done with the warmup period, use the decay style.
    num_steps_ = step - warmup_steps
    decay_steps_ = decay_steps
    decay_ratio = float(num_steps_) / float(decay_steps_)
    assert decay_ratio >= 0.0
    assert decay_ratio <= 1.0
    delta_lr = max_lr - min_lr

    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)

    return min_lr + coeff * delta_lr


def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
    if cycle:
        multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
        decay_steps *= multiplier
    else:
        step = min(step, decay_steps)
    p = step / decay_steps
    lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
    lr += min_lr
    return lr


def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
                         decay_rate, min_lr):
    # hold_steps = total number of steps
    # to hold the LR, not the warmup + hold steps.
    T_warmup_decay = max(1, warmup_steps ** decay_rate)
    T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
    lr = (initial_lr * T_warmup_decay) / T_hold_decay
    lr = max(lr, min_lr)
    return lr


class SquareAnnealing(WarmupPolicy):

    def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1,
                 **kwargs):
        super().__init__(optimizer=optimizer,
                         max_steps=max_steps,
                         last_epoch=last_epoch,
                         min_lr=min_lr,
                         **kwargs)

    def _get_lr(self, step):
        new_lrs = [
            _square_annealing(
                initial_lr=initial_lr,
                step=step - self.warmup_steps,
                max_steps=self.max_steps - self.warmup_steps,
                min_lr=self.min_lr,
            ) for initial_lr in self.base_lrs
        ]
        return new_lrs


class SquareRootAnnealing(WarmupPolicy):

    def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1,
                 **kwargs):
        super().__init__(optimizer=optimizer,
                         max_steps=max_steps,
                         last_epoch=last_epoch,
                         min_lr=min_lr,
                         **kwargs)

    def _get_lr(self, step):
        new_lrs = [
            _squareroot_annealing(initial_lr=initial_lr,
                                  step=step,
                                  max_steps=self.max_steps,
                                  min_lr=self.min_lr)
            for initial_lr in self.base_lrs
        ]
        return new_lrs


class CosineAnnealing(WarmupAnnealHoldPolicy):

    def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1,
                 **kwargs):
        super().__init__(optimizer=optimizer,
                         max_steps=max_steps,
                         last_epoch=last_epoch,
                         min_lr=min_lr,
                         **kwargs)

    def _get_lr(self, step):
        for initial_lr in self.base_lrs:
            if initial_lr < self.min_lr:
                raise ValueError(
                    f"{self} received an initial learning rate "
                    f"that was lower than the minimum learning rate.")

        if self.constant_steps is None or self.constant_steps == 0:
            new_lrs = [
                _cosine_annealing(
                    initial_lr=initial_lr,
                    step=step - self.warmup_steps,
                    max_steps=self.max_steps - self.warmup_steps,
                    min_lr=self.min_lr,
                ) for initial_lr in self.base_lrs
            ]
        else:
            new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
        return new_lrs

    def _get_warmup_lr(self, step):
        if self.constant_steps is None or self.constant_steps == 0:
            return super()._get_warmup_lr(step)
        else:
            # Use linear warmup for the initial part.
            return self._get_linear_warmup_with_cosine_annealing_lr(step)

    def _get_constant_lr(self, step):
        # Only called when `constant_steps` > 0.
        return self._get_linear_warmup_with_cosine_annealing_lr(step)

    def _get_linear_warmup_with_cosine_annealing_lr(self, step):
        # Cosine Schedule for Megatron LM,
        # slightly different warmup schedule + constant LR at the end.
        new_lrs = [
            _linear_warmup_with_cosine_annealing(
                max_lr=self.base_lrs[0],
                warmup_steps=self.warmup_steps,
                step=step,
                decay_steps=self.decay_steps,
                min_lr=self.min_lr,
            ) for _ in self.base_lrs
        ]
        return new_lrs


class NoamAnnealing(_LRScheduler):

    def __init__(self,
                 optimizer,
                 *,
                 d_model,
                 warmup_steps=None,
                 warmup_ratio=None,
                 max_steps=None,
                 min_lr=0.0,
                 last_epoch=-1):
        self._normalize = d_model ** (-0.5)
        assert not (warmup_steps is not None and warmup_ratio is not None), \
            "Either use particular number of step or ratio"
        assert warmup_ratio is None or max_steps is not None, \
            "If there is a ratio, there should be a total steps"

        # It is necessary to assign all attributes *before* __init__,
        # as class is wrapped by an inner class.
        self.max_steps = max_steps
        if warmup_steps is not None:
            self.warmup_steps = warmup_steps
        elif warmup_ratio is not None:
            self.warmup_steps = int(warmup_ratio * max_steps)
        else:
            self.warmup_steps = 0

        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed "
                "by the scheduler, please use `get_last_lr()`.",
                UserWarning,
                stacklevel=2)

        step = max(1, self.last_epoch)

        for initial_lr in self.base_lrs:
            if initial_lr < self.min_lr:
                raise ValueError(
                    f"{self} received an initial learning rate "
                    f"that was lower than the minimum learning rate.")

        new_lrs = [
            self._noam_annealing(initial_lr=initial_lr, step=step)
            for initial_lr in self.base_lrs
        ]
        return new_lrs

    def _noam_annealing(self, initial_lr, step):
        if self.warmup_steps > 0:
            mult = self._normalize * min(step ** (-0.5),
                                         step * (self.warmup_steps ** (-1.5)))
        else:
            mult = self._normalize * step ** (-0.5)

        out_lr = initial_lr * mult
        if step > self.warmup_steps:
            out_lr = max(out_lr, self.min_lr)
        return out_lr


class NoamHoldAnnealing(WarmupHoldPolicy):

    def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0,
                 last_epoch=-1, **kwargs):
        """
        From NeMo:
        Implementation of the Noam Hold Annealing policy
        from the SqueezeFormer paper.

        Unlike NoamAnnealing, the peak learning rate
        can be explicitly set for this scheduler.
        The schedule first performs linear warmup,
        then holds the peak LR, then decays with some schedule for
        the remainder of the steps.
        Therefore the min lr is still dependent
        on the hyperparameters selected.

        Its schedule is determined by three factors:

        Warmup Steps: Initial stage, where linear warmup
            occurs until the peak LR is reached. Unlike NoamAnnealing,
            the peak LR is explicitly stated here instead of a scaling factor.

        Hold Steps: Intermediate stage, where the peak LR
            is maintained for some number of steps. In this region,
            the high peak LR allows the model to converge faster
            if training is stable. However the high LR
            may also cause instability during training.
            Should usually be a significant fraction of training
            steps (around 30-40% of the entire training steps).

        Decay Steps: Final stage, where the LR rapidly decays
            with some scaling rate (set by decay rate).
            To attain Noam decay, use 0.5;
            for the Squeezeformer recommended decay, use 1.0.
            The fast decay after prolonged high LR during
            the hold phase allows for rapid convergence.

        References:
            - [Squeezeformer:
              An Efficient Transformer for Automatic Speech Recognition]
              (https://arxiv.org/abs/2206.00888)

        Args:
            optimizer: Pytorch compatible Optimizer object.
            warmup_steps: Number of training steps in warmup stage
            warmup_ratio: Ratio of warmup steps to total steps
            hold_steps: Number of training steps to
                hold the learning rate after warm up
            hold_ratio: Ratio of hold steps to total steps
            max_steps: Total number of steps while training or `None` for
                infinite training
            decay_rate: Float value describing the polynomial decay
                after the hold period. Default value
                of 0.5 corresponds to Noam decay.
            min_lr: Minimum learning rate.
        """
        self.decay_rate = decay_rate
        super().__init__(optimizer=optimizer,
                         max_steps=max_steps,
                         last_epoch=last_epoch,
                         min_lr=min_lr,
                         **kwargs)

    def _get_lr(self, step):
        if self.warmup_steps is None or self.warmup_steps == 0:
            raise ValueError(
                "Noam scheduler cannot be used without warmup steps")

        if self.hold_steps > 0:
            hold_steps = self.hold_steps - self.warmup_steps
        else:
            hold_steps = 0

        new_lrs = [
            _noam_hold_annealing(
                initial_lr,
                step=step,
                warmup_steps=self.warmup_steps,
                hold_steps=hold_steps,
                decay_rate=self.decay_rate,
                min_lr=self.min_lr,
            ) for initial_lr in self.base_lrs
        ]
        return new_lrs

    def set_step(self, step: int):
        self.last_epoch = step


class ConstantLR(_LRScheduler):
    """The ConstantLR scheduler

    This scheduler keeps a constant lr
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
    ):
        # __init__() must be invoked before setting field
        # because step() is also invoked in __init__()
        super().__init__(optimizer)

    def get_lr(self):
        return self.base_lrs

    def set_step(self, step: int):
        self.last_epoch = step
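A minimal sketch of driving one of these schedulers; the stand-in model, the base learning rate, and the warmup length below are illustrative assumptions:

# Hypothetical usage sketch: WarmupLR ramps the lr up over warmup_steps,
# then decays it proportionally to step ** -0.5.
import torch

model = torch.nn.Linear(10, 10)                       # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupLR(optimizer, warmup_steps=1000)

for step in range(5000):
    optimizer.step()   # backward() omitted in this sketch
    scheduler.step()   # lr peaks at 1e-3 around step 1000, then decays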
examples/music_generation/inspiremusic/utils/tokenizer_utils.py
0 → 100644
import glob
import json
import os
import random
import sys
import time
import warnings

import matplotlib
import numpy as np
import torch
import yaml
from torch import distributed as dist
from torch.nn.utils import weight_norm
from torch.utils.tensorboard import SummaryWriter

matplotlib.use("Agg")
import matplotlib.pylab as plt

import re
import pathlib


def seed_everything(seed, cudnn_deterministic=False):
    """
    Function that sets seed for pseudo-random number generators in:
    pytorch, numpy, python.random

    Args:
        seed: the integer value seed for global random state
    """
    if seed is not None:
        # print(f"Global seed set to {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # if cudnn_deterministic:
    #     torch.backends.cudnn.deterministic = True
    #     warnings.warn('You have chosen to seed training. '
    #                   'This will turn on the CUDNN deterministic setting, '
    #                   'which can slow down your training considerably! '
    #                   'You may see unexpected behavior when restarting '
    #                   'from checkpoints.')


def is_primary():
    return get_rank() == 0


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def load_yaml_config(path):
    with open(path) as f:
        config = yaml.full_load(f)
    return config


def save_config_to_yaml(config, path):
    assert path.endswith('.yaml')
    with open(path, 'w') as f:
        f.write(yaml.dump(config))


def save_dict_to_json(d, path, indent=None):
    json.dump(d, open(path, 'w'), indent=indent)


def load_dict_from_json(path):
    return json.load(open(path, 'r'))


def write_args(args, path):
    args_dict = dict((name, getattr(args, name)) for name in dir(args)
                     if not name.startswith('_'))
    with open(path, 'a') as args_file:
        args_file.write('==> torch version: {}\n'.format(torch.__version__))
        args_file.write(
            '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
        args_file.write('==> Cmd:\n')
        args_file.write(str(sys.argv))
        args_file.write('\n==> args:\n')
        for k, v in sorted(args_dict.items()):
            args_file.write('  %s: %s\n' % (str(k), str(v)))


class Logger(object):

    def __init__(self, args):
        self.args = args
        self.save_dir = args.save_dir
        self.is_primary = is_primary()

        if self.is_primary:
            os.makedirs(self.save_dir, exist_ok=True)

            # save the args and config
            self.config_dir = os.path.join(self.save_dir, 'configs')
            os.makedirs(self.config_dir, exist_ok=True)
            file_name = os.path.join(self.config_dir, 'args.txt')
            write_args(args, file_name)

            log_dir = os.path.join(self.save_dir, 'logs')
            if not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            self.text_writer = open(os.path.join(log_dir, 'log.txt'), 'a')  # 'w')
            if args.tensorboard:
                self.log_info('using tensorboard')
                self.tb_writer = SummaryWriter(log_dir=log_dir)
            else:
                self.tb_writer = None

    def save_config(self, config):
        if self.is_primary:
            save_config_to_yaml(config,
                                os.path.join(self.config_dir, 'config.yaml'))

    def log_info(self, info, check_primary=True):
        if self.is_primary or (not check_primary):
            print(info)
            if self.is_primary:
                info = str(info)
                time_str = time.strftime('%Y-%m-%d-%H-%M')
                info = '{}: {}'.format(time_str, info)
                if not info.endswith('\n'):
                    info += '\n'
                self.text_writer.write(info)
                self.text_writer.flush()

    def add_scalar(self, **kargs):
        """Log a scalar variable."""
        if self.is_primary:
            if self.tb_writer is not None:
                self.tb_writer.add_scalar(**kargs)

    def add_scalars(self, **kargs):
        """Log several scalar variables."""
        if self.is_primary:
            if self.tb_writer is not None:
                self.tb_writer.add_scalars(**kargs)

    def add_image(self, **kargs):
        """Log an image."""
        if self.is_primary:
            if self.tb_writer is not None:
                self.tb_writer.add_image(**kargs)

    def add_images(self, **kargs):
        """Log a batch of images."""
        if self.is_primary:
            if self.tb_writer is not None:
                self.tb_writer.add_images(**kargs)

    def close(self):
        if self.is_primary:
            self.text_writer.close()
            self.tb_writer.close()


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj, num_ckpt_keep=5):
    name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1)
    ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*'))
    if len(ckpts) > num_ckpt_keep:
        [os.remove(c) for c in ckpts[:-num_ckpt_keep]]
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
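A short sketch of the checkpoint helpers. Note that save_checkpoint expects file names matching `(do|g)_\d+` (HiFi-GAN style, e.g. `g_00010000`); the `ckpt` directory and the state dict below are illustrative assumptions:

# Hypothetical usage sketch: rotate generator checkpoints, keeping the
# newest five, then locate the latest one on restart.
import os
import torch

seed_everything(1234)
os.makedirs('ckpt', exist_ok=True)
state = {'generator': torch.nn.Linear(4, 4).state_dict(), 'steps': 10000}
save_checkpoint('ckpt/g_00010000', state, num_ckpt_keep=5)

latest = scan_checkpoint('ckpt', 'g_')   # returns None if nothing matches
if latest is not None:
    ckpt = load_checkpoint(latest, device='cpu')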
examples/music_generation/inspiremusic/utils/train_utils.py
0 → 100644
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Horizon Inc. (authors: Xingchen Song)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import logging
import os
import torch
import json
import re
import datetime
import yaml

import deepspeed
import torch.optim as optim
import torch.distributed as dist

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live

from inspiremusic.dataset.dataset import Dataset
from inspiremusic.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR


def init_distributed(args):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
                 ', rank {}, world_size {}'.format(rank, world_size))
    if args.train_engine == 'torch_ddp':
        torch.cuda.set_device(local_rank)
        dist.init_process_group(args.dist_backend)
    else:
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    return world_size, local_rank, rank


def init_dataset_and_dataloader(args, configs):
    gan = False
    data_pipeline = configs['data_pipeline_gan'] if gan is True \
        else configs['data_pipeline']
    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline,
                            mode='train', shuffle=True, partition=True)
    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline,
                         mode='train', shuffle=False, partition=False)

    # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   prefetch_factor=args.prefetch,
                                   timeout=60)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                prefetch_factor=args.prefetch,
                                timeout=60)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def check_modify_and_save_config(args, configs):
    if args.train_engine == "torch_ddp":
        configs['train_conf']["dtype"] = 'fp32'
    else:
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs['train_conf']["dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs['train_conf']["dtype"] = "bf16"
        else:
            configs['train_conf']["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        # if use deepspeed, override ddp config
        configs['train_conf']['save_per_step'] = int(
            configs['train_conf']['save_per_step'] *
            configs['train_conf']['accum_grad'] /
            ds_configs["gradient_accumulation_steps"])
        configs['train_conf']['accum_grad'] = \
            ds_configs["gradient_accumulation_steps"]
        configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
        configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
    return configs


def wrap_cuda_model(args, model):
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if args.train_engine == "torch_ddp":  # native pytorch ddp
        assert (torch.cuda.is_available())
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
    else:
        if int(os.environ.get('RANK', 0)) == 0:
            logging.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
    return model


def init_optimizer_and_scheduler(args, configs, model):
    if configs['train_conf']['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               **configs['train_conf']['optim_conf'])
    elif configs['train_conf']['optim'] == 'adamw':
        optimizer = optim.AdamW(model.parameters(),
                                **configs['train_conf']['optim_conf'])
    else:
        raise ValueError("unknown optimizer: " + configs['train_conf'])

    if configs['train_conf']['scheduler'] == 'warmuplr':
        scheduler_type = WarmupLR
        scheduler = WarmupLR(optimizer,
                             **configs['train_conf']['scheduler_conf'])
    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
        scheduler_type = NoamHoldAnnealing
        scheduler = NoamHoldAnnealing(optimizer,
                                      **configs['train_conf']['scheduler_conf'])
    elif configs['train_conf']['scheduler'] == 'constantlr':
        scheduler_type = ConstantLR
        scheduler = ConstantLR(optimizer)
    else:
        raise ValueError("unknown scheduler: " + configs['train_conf'])

    # use deepspeed optimizer for speedup
    if args.train_engine == "deepspeed":
        def scheduler(opt):
            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
        model, optimizer, _, scheduler = deepspeed.initialize(
            args=args,
            model=model,
            optimizer=None,
            lr_scheduler=scheduler,
            model_parameters=model.parameters())

    return model, optimizer, scheduler


def init_summarywriter(args):
    writer = None
    if int(os.environ.get('RANK', 0)) == 0:
        os.makedirs(args.model_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)
    return writer


def save_model(model, model_name, info_dict):
    rank = int(os.environ.get('RANK', 0))
    model_dir = info_dict["model_dir"]
    save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))

    if info_dict["train_engine"] == "torch_ddp":
        if rank == 0:
            torch.save(model.module.state_dict(), save_model_path)
    else:
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=model_name,
                                  client_state=info_dict)
    if rank == 0:
        info_path = re.sub('.pt$', '.yaml', save_model_path)
        info_dict['save_time'] = datetime.datetime.now().strftime(
            '%d/%m/%Y %H:%M:%S')
        with open(info_path, 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)
        logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(
            rank, save_model_path))


def inspiremusic_join(group_join, info_dict):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))

    if info_dict["batch_idx"] != 0:
        # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
        try:
            dist.monitored_barrier(group=group_join,
                                   timeout=group_join.options._timeout)
            return False
        except RuntimeError as e:
            logging.info("Detected uneven workload distribution: {}\n".format(e) +
                         "Break current worker to manually join all workers, " +
                         "world_size {}, current rank {}, current local_rank {}\n".format(
                             world_size, rank, local_rank))
            return True
    else:
        return False


def batch_forward(model, batch, info_dict, scaler):
    device = int(os.environ.get('LOCAL_RANK', 0))

    dtype = info_dict["dtype"]
    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    else:  # fp32
        dtype = torch.float32

    if info_dict['train_engine'] == 'torch_ddp':
        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
    else:
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype,
                                           cache_enabled=False)

    with autocast:
        info_dict['loss_dict'] = model(batch, device)
    return info_dict


def batch_backward(model, info_dict, scaler):
    if info_dict["train_engine"] == "deepspeed":
        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
    else:
        scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
        if scaler is not None:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

    info_dict['loss_dict']['loss'] = scaled_loss
    return info_dict


def update_parameter_and_lr(model, optimizer, scheduler, info_dict,
                            scaler=None):
    grad_norm = 0.0
    if info_dict['train_engine'] == "deepspeed":
        info_dict["is_gradient_accumulation_boundary"] = \
            model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
    elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
        if scaler is not None:
            scaler.unscale_(optimizer)  # Unscale gradients before clipping
            grad_norm = clip_grad_norm_(model.parameters(),
                                        info_dict['grad_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            grad_norm = clip_grad_norm_(model.parameters(),
                                        info_dict['grad_clip'])
            if torch.isfinite(grad_norm):
                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    info_dict["lr"] = optimizer.param_groups[0]['lr']
    info_dict["grad_norm"] = grad_norm
    return info_dict


def log_per_step(writer, info_dict):
    tag = info_dict["tag"]
    epoch = info_dict.get('epoch', 0)
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    rank = int(os.environ.get('RANK', 0))

    # only rank 0 write to tensorboard to avoid multi-process write
    if writer is not None:
        if (info_dict['train_engine'] == 'deepspeed' and
                info_dict['is_gradient_accumulation_boundary'] is True) or \
           (info_dict['train_engine'] == 'torch_ddp' and
                (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
            for k in ['epoch', 'lr', 'grad_norm']:
                writer.add_scalar('{}/{}'.format(tag, k), info_dict[k],
                                  step + 1)
            for k, v in loss_dict.items():
                writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)

    # TRAIN & CV, Shell log (stdout)
    if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
        log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
        for name, value in loss_dict.items():
            log_str += '{} {:.6f} '.format(name, value.item())
        if tag == "TRAIN":
            log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
                info_dict["lr"], info_dict['grad_norm'])
        log_str += ' rank {}'.format(rank)
        logging.debug(log_str)


def log_per_save(writer, info_dict):
    tag = info_dict["tag"]
    epoch = info_dict["epoch"]
    step = info_dict["step"]
    loss_dict = info_dict["loss_dict"]
    lr = info_dict['lr']
    rank = int(os.environ.get('RANK', 0))
    logging.info(
        'Epoch {} Step {} CV info lr {} {} rank {}'.format(
            epoch, step + 1, lr,
            ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]),
            rank))

    if writer is not None:
        for k in ['epoch', 'lr']:
            writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
        for k, v in loss_dict.items():
            writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
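A condensed sketch of how these helpers chain together inside a training loop. The `info_dict` keys shown are the ones the functions above actually read; the loop itself is illustrative, assuming `model`, `optimizer`, `scheduler`, `writer`, and `train_data_loader` were produced by the init_* helpers, and it is not the repository's real executor:

# Hypothetical training-step sketch (the real loop lives in executor.py).
info_dict = {'train_engine': 'torch_ddp', 'dtype': 'fp32', 'accum_grad': 2,
             'grad_clip': 5.0, 'log_interval': 100, 'tag': 'TRAIN',
             'step': 0, 'epoch': 0}
scaler = None  # a torch.cuda.amp.GradScaler instance when fp16 is enabled

for batch_idx, batch in enumerate(train_data_loader):
    info_dict['batch_idx'] = batch_idx
    info_dict = batch_forward(model, batch, info_dict, scaler)
    info_dict = batch_backward(model, info_dict, scaler)
    info_dict = update_parameter_and_lr(model, optimizer, scheduler,
                                        info_dict, scaler)
    log_per_step(writer, info_dict)
    info_dict['step'] += 1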
examples/music_generation/inspiremusic/utils/utils.py
0 → 100644
import os
import sys


def align_trans_scp_file(trans, scp):
    trans_dict = {}
    with open(trans, 'r') as f:
        for line in f:
            sec = line.strip().split("\t")
            trans_dict[sec[0]] = sec[1]
    scp_dict = {}
    with open(scp, 'r') as f:
        for line in f:
            sec = line.strip().split(" ")
            scp_dict[sec[0]] = sec[1]
    with open("text", "w") as f:
        for k, v in scp_dict.items():
            f.write("%s\t%s\n" % (k, trans_dict[k]))


if __name__ == '__main__':
    trans = sys.argv[1]
    scp = sys.argv[2]
    align_trans_scp_file(trans, scp)
\ No newline at end of file
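For clarity, a hypothetical pair of inputs this script expects: a tab-separated transcript file and a space-separated Kaldi-style wav.scp sharing utterance ids in column one. It writes a `text` file mapping each scp id to its transcript; the file names and contents below are invented for illustration:

# Hypothetical usage sketch for align_trans_scp_file.
with open('trans.txt', 'w') as f:        # tab-separated transcripts
    f.write('utt001\tsome caption text\n')
with open('wav.scp', 'w') as f:          # space-separated audio paths
    f.write('utt001 /data/audio/utt001.wav\n')

align_trans_scp_file('trans.txt', 'wav.scp')   # writes ./text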
examples/music_generation/inspiremusic/version.txt
0 → 100644
v0.1
\ No newline at end of file
examples/music_generation/inspiremusic/wavtokenizer/__init__.py
0 → 100644
examples/music_generation/inspiremusic/wavtokenizer/__pycache__/__init__.cpython-310.pyc
0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__init__.py
0 → 100644