ModelZoo / InspireMusic_pytorch / Commits

Commit 0112b0f0, authored Feb 14, 2025 by chenzk
v1.0
Pipeline #2394: canceled. The full commit changes 474 files; this page shows 20 changed files with 2038 additions and 0 deletions (+2038, -0).
Changed files shown on this page:

inspiremusic/wavtokenizer/decoder/heads.py (+159, -0)
inspiremusic/wavtokenizer/decoder/helpers.py (+71, -0)
inspiremusic/wavtokenizer/decoder/loss.py (+159, -0)
inspiremusic/wavtokenizer/decoder/models.py (+266, -0)
inspiremusic/wavtokenizer/decoder/modules.py (+214, -0)
inspiremusic/wavtokenizer/decoder/pretrained.py (+253, -0)
inspiremusic/wavtokenizer/decoder/pretrained_model.py (+192, -0)
inspiremusic/wavtokenizer/decoder/spectral_ops.py (+242, -0)
inspiremusic/wavtokenizer/encoder/__init__.py (+12, -0)
inspiremusic/wavtokenizer/encoder/__pycache__/__init__.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/__pycache__/distrib.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/__pycache__/model.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/__pycache__/utils.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/distrib.py (+124, -0)
inspiremusic/wavtokenizer/encoder/model.py (+324, -0)
inspiremusic/wavtokenizer/encoder/modules/__init__.py (+22, -0)
inspiremusic/wavtokenizer/encoder/modules/__pycache__/__init__.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/modules/__pycache__/conv.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/modules/__pycache__/lstm.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/encoder/modules/__pycache__/norm.cpython-310.pyc (binary)
inspiremusic/wavtokenizer/decoder/heads.py (new file, mode 100644)

from typing import Optional

import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)


class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # wrapping happens here. These two lines produce real and imaginary value
        x = torch.cos(p)
        y = torch.sin(p)
        # recalculating phase here does not produce anything new
        # only costs time
        # phase = torch.atan2(y, x)
        # S = mag * torch.exp(phase * 1j)
        # better directly produce the complex value
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio


class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally init the last layer following mel-scale
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            # fixed: the original clipped the pre-IMDCT coefficients `x` here instead of `audio`
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio


class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            # fixed: the original clipped the pre-IMDCT tensor `x` here instead of `audio`
            audio = torch.clip(audio, min=-1.0, max=1.0)
        return audio
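Example (not part of the commit): a minimal shape-check sketch for ISTFTHead; the dimensions below are arbitrary illustrative values, not values used by this repository.

import torch
from inspiremusic.wavtokenizer.decoder.heads import ISTFTHead

head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
x = torch.randn(2, 100, 512)   # (B, L, H): batch 2, 100 frames, model dim 512
audio = head(x)                # (B, T) waveform
# With "same" padding the output length is exactly L * hop_length
print(audio.shape)             # torch.Size([2, 25600])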
inspiremusic/wavtokenizer/decoder/helpers.py (new file, mode 100644)

import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback

matplotlib.use("Agg")


def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
    """
    Save a matplotlib figure to a numpy array.

    Args:
        fig (Figure): Matplotlib figure object.

    Returns:
        ndarray: Numpy array representing the figure.
    """
    # np.frombuffer replaces the long-deprecated np.fromstring(..., sep="") used originally.
    # (Note: tostring_rgb itself is removed in Matplotlib >= 3.10; buffer_rgba is its successor.)
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
    """
    Plot a spectrogram and convert it to a numpy array.

    Args:
        spectrogram (ndarray): Spectrogram data.

    Returns:
        ndarray: Numpy array representing the plotted spectrogram.
    """
    spectrogram = spectrogram.astype(np.float32)
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


class GradNormCallback(Callback):
    """
    Callback to log the gradient norm.
    """

    def on_after_backward(self, trainer, model):
        model.log("grad_norm", gradient_norm(model))


def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the gradient norm.

    Args:
        model (Module): PyTorch model.
        norm_type (float, optional): Type of the norm. Defaults to 2.0.

    Returns:
        Tensor: Gradient norm.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
    return total_norm
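Example (not part of the commit): gradient_norm can be checked directly on any module after a backward pass; the linear layer below is a stand-in, not part of this repository.

import torch
from inspiremusic.wavtokenizer.decoder.helpers import GradNormCallback, gradient_norm

net = torch.nn.Linear(8, 1)
net(torch.randn(4, 8)).sum().backward()
print(gradient_norm(net))  # scalar tensor: global 2-norm over all parameter gradients

# In training, the callback logs the same value each optimization step:
# trainer = pytorch_lightning.Trainer(callbacks=[GradNormCallback()])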
inspiremusic/wavtokenizer/decoder/loss.py (new file, mode 100644)

from typing import List, Tuple

import torch
import torchaudio
import torch.nn.functional as F
from torch import nn

# The original imported from the unqualified `decoder.modules`; the package-qualified
# path used everywhere else in this commit is applied here for consistency.
from inspiremusic.wavtokenizer.decoder.modules import safe_log


class MelSpecReconstructionLoss(nn.Module):
    """
    L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
    ):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, y_hat, y) -> torch.Tensor:
        """
        Args:
            y_hat (Tensor): Predicted audio waveform.
            y (Tensor): Ground truth audio waveform.

        Returns:
            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
        """
        mel_hat = safe_log(self.mel_spec(y_hat))
        mel = safe_log(self.mel_spec(y))
        loss = torch.nn.functional.l1_loss(mel, mel_hat)
        return loss


class GeneratorLoss(nn.Module):
    """
    Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
    """

    def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            disc_outputs (List[Tensor]): List of discriminator outputs.

        Returns:
            Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                the sub-discriminators.
        """
        loss = 0
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean(torch.clamp(1 - dg, min=0))
            gen_losses.append(l)
            loss += l
        return loss, gen_losses


class DiscriminatorLoss(nn.Module):
    """
    Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
    """

    def forward(
        self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
            disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.

        Returns:
            Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
                the sub-discriminators for real outputs, and a list of loss values for generated outputs.
        """
        loss = 0
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean(torch.clamp(1 - dr, min=0))
            g_loss = torch.mean(torch.clamp(1 + dg, min=0))
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())
        return loss, r_losses, g_losses


class FeatureMatchingLoss(nn.Module):
    """
    Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
    """

    def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
        """
        Args:
            fmap_r (List[List[Tensor]]): List of feature maps from real samples.
            fmap_g (List[List[Tensor]]): List of feature maps from generated samples.

        Returns:
            Tensor: The calculated feature matching loss.
        """
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))
        return loss


class DACGANLoss(nn.Module):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        # d_fake = self.discriminator(fake.audio_data)
        # d_real = self.discriminator(real.audio_data)
        d_fake = self.discriminator(fake)
        d_real = self.discriminator(real)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)
        return loss_d

    def generator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake, real)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0
        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
        return loss_g, loss_feature
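Example (not part of the commit): the hinge losses above can be sanity-checked on fabricated discriminator scores; the two sub-discriminator outputs below are made up just to show shapes and the direction of each term.

import torch
from inspiremusic.wavtokenizer.decoder.loss import GeneratorLoss, DiscriminatorLoss

gen_loss = GeneratorLoss()
disc_loss = DiscriminatorLoss()

d_real = [torch.full((4, 10), 2.0), torch.full((4, 10), 1.5)]    # confident "real" scores
d_fake = [torch.full((4, 10), -2.0), torch.full((4, 10), -1.5)]  # confident "fake" scores

loss_d, r_losses, g_losses = disc_loss(d_real, d_fake)
print(loss_d)       # 0.0: both hinge terms are fully satisfied at these margins
loss_g, per_disc = gen_loss(d_fake)
print(loss_g)       # 5.5 = (1 - (-2.0)) + (1 - (-1.5)): generator heavily penalized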
inspiremusic/wavtokenizer/decoder/models.py (new file, mode 100644)

import typing as tp  # fixed: needed for the tp.List annotation in VocosBackbone, missing in the original
from typing import Optional

import torch
from torch import nn
from torch.nn.utils import weight_norm

from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # NOTE: the [:, :, None, None] indexing is inherited from a 2D version of this block;
            # with Conv1d activations of shape (B, C, T) it would raise. The branch is unreachable
            # in this codebase because temb_channels is always 0 here.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h = q.shape
        q = q.permute(0, 2, 1)   # b,hw,c
        w_ = torch.bmm(q, k)     # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    # NOTE: "linear" and "none" pass the assertion but fall through and return None.


class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

        self.temb_ch = 0
        block_in = dim
        dropout = 0.1
        attn_type = "vanilla"
        pos_net: tp.List[nn.Module] = [
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            make_attn(block_in, attn_type=attn_type),
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            Normalize(block_in),
        ]
        self.pos_net = nn.Sequential(*pos_net)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embed(x)
        x = self.pos_net(x)
        if self.adanorm:
            # assert bandwidth_id is not None
            if bandwidth_id is None:
                # NOTE: hard-coded CUDA device; a CPU-only run would need x.device here.
                bandwidth_id = torch.tensor(0, device='cuda')
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x


class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
    """

    def __init__(
        self,
        input_channels,
        dim,
        num_blocks,
        layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        self.resnet = nn.Sequential(
            *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.embed(x)
        x = self.resnet(x)
        x = x.transpose(1, 2)
        return x
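Example (not part of the commit): a shape-flow sketch for VocosResNetBackbone, the simpler of the two backbones; all sizes are illustrative.

import torch
from inspiremusic.wavtokenizer.decoder.models import VocosResNetBackbone

backbone = VocosResNetBackbone(input_channels=512, dim=768, num_blocks=3)
feats = torch.randn(2, 512, 50)  # (B, C, L) features, e.g. from the feature extractor
out = backbone(feats)            # temporal resolution is preserved
print(out.shape)                 # torch.Size([2, 50, 768]) -> (B, L, H), ready for a head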
inspiremusic/wavtokenizer/decoder/modules.py (new file, mode 100644)

from typing import Optional
from typing import Tuple

import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            # fixed: the original compared `layer_scale_init_value > 0` directly, which raises
            # TypeError for the documented None default; guard the None case first.
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x


class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, ...] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(dim, dim, kernel_size, 1, dilation=dilation[0],
                              padding=self.get_padding(kernel_size, dilation[0]))
                ),
                weight_norm(
                    nn.Conv1d(dim, dim, kernel_size, 1, dilation=dilation[1],
                              padding=self.get_padding(kernel_size, dilation[1]))
                ),
                weight_norm(
                    nn.Conv1d(dim, dim, kernel_size, 1, dilation=dilation[2],
                              padding=self.get_padding(kernel_size, dilation[2]))
                ),
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
            ]
        )
        self.gamma = nn.ParameterList(
            [
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)
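Example (not part of the commit): symlog and symexp are inverses on all of ℝ, which is what lets IMDCTSymExpHead predict coefficients in a compressed range; a quick numerical check.

import torch
from inspiremusic.wavtokenizer.decoder.modules import symlog, symexp

x = torch.linspace(-50.0, 50.0, steps=101)
assert torch.allclose(symexp(symlog(x)), x, atol=1e-3)  # symexp after symlog recovers x
assert torch.allclose(symlog(symexp(x)), x, atol=1e-3)  # and vice versa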
inspiremusic/wavtokenizer/decoder/pretrained.py (new file, mode 100644)

import os
from typing import Tuple, Any, Union, Dict

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn

from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path": ..., "init_args": ...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


class WavTokenizer(nn.Module):
    """
    WavTokenizer (adapted from the Vocos class) is a Fourier-based neural vocoder for audio synthesis.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["backbone"])
        head = instantiate_class(args=(), init=config["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained(cls, repo_id: str) -> "WavTokenizer":
        """
        Class method to create a new model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        model = cls.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu")
        if isinstance(model.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @classmethod
    def from_hparams_feat(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new model instance from hyperparameters stored in a yaml configuration file
        (nested under model.init_args).
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained_feat(cls, config_path, model_path) -> "WavTokenizer":
        """
        Class method to create a new model instance from a pre-trained checkpoint stored on local disk.
        """
        model = cls.from_hparams_feat(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @classmethod
    def estimator(cls, config_path, model_path) -> "WavTokenizer":
        """
        Class method to create a new model instance from a pre-trained checkpoint stored on local disk.
        (Identical loading path to `from_pretrained_feat`, kept as a separate entry point.)
        """
        model = cls.from_hparams_feat(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @classmethod
    def from_pretrained0911(cls, config_path, model_folder_path) -> "WavTokenizer":
        """
        Class method that averages the best checkpoints found in a folder into a single model.
        """
        # NOTE: `from_hparams0802` is defined on the WavTokenizer variant in pretrained_model.py,
        # not in this file; calling this method as written raises AttributeError.
        model = cls.from_hparams0802(config_path)

        models = os.listdir(model_folder_path)
        val_loss = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
            val_loss.append(item[-11:-5])
        val_loss.sort()
        val_loss = val_loss[:3]  # average the three checkpoints with the best (lowest) validation loss
        state_dict = dict()
        state_dicts = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
            ll = item[-11:-5]
            if ll not in val_loss:
                continue
            model_path = model_folder_path + '/' + item
            state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
            state_dict_single = dict()
            for k, v in state_dict_raw.items():
                if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                    state_dict_single[k] = v
            state_dicts.append(state_dict_single)
        for kk in state_dicts[0].keys():
            vv = state_dicts[0][kk]
            for i in range(1, len(state_dicts)):
                ss = state_dicts[i]
                vv += ss[kk]
            vm = vv / len(state_dicts)
            state_dict[kk] = vm
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features, _, _ = self.feature_extractor(audio_input, **kwargs)  # 0818
        audio_output = self.decode(features, **kwargs)
        return audio_output

    # 0818
    @torch.inference_mode()
    def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
        return features, discrete_codes

    # 0818
    @torch.inference_mode()
    def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
        return features, discrete_codes

    @torch.inference_mode()
    def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        _, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs)
        discrete_codes = discrete_codes.clamp(min=0, max=16383)
        return discrete_codes

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.
        """
        assert isinstance(self.feature_extractor, EncodecFeatures), "Feature extractor should be an instance of EncodecFeatures"

        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        tmp = torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0)
        features = features.transpose(1, 2)
        return features
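Example (not part of the commit): a hypothetical end-to-end sketch of the local-checkpoint loading path; "config.yaml" and "model.ckpt" are placeholder paths, and the one-second 24 kHz mono input is an assumption about the expected sample rate.

import torch
from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer

model = WavTokenizer.from_pretrained_feat("config.yaml", "model.ckpt")
wav = torch.randn(1, 24000)           # placeholder: 1 s of mono audio at 24 kHz
features, codes = model.encode(wav)   # continuous features plus discrete codes
recon = model.decode(features)        # copy-synthesis back to a waveform
print(recon.shape)                    # (B, T)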
inspiremusic/wavtokenizer/decoder/pretrained_model.py (new file, mode 100644)

from typing import Tuple, Any, Union, Dict

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn

from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path": ..., "init_args": ...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


class WavTokenizer(nn.Module):
    """
    WavTokenizer (adapted from the Vocos class) is a Fourier-based neural vocoder for audio synthesis.
    This variant also carries the discriminators used during training. It consists of a feature extractor,
    a backbone, a head, and two discriminators.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
        multiperioddisc: MultiPeriodDiscriminator,
        multiresddisc: MultiResolutionDiscriminator,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head
        self.multiperioddisc = multiperioddisc
        self.multiresddisc = multiresddisc

    @classmethod
    def from_hparams0828(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
                    multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
                    multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
        return model

    @classmethod
    def from_pretrained0828(cls, config_path, model_path) -> "WavTokenizer":
        """
        Class method to create a new model instance from a pre-trained checkpoint stored on local disk,
        including the discriminator weights.
        """
        model = cls.from_hparams0828(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
                    or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
                state_dict[k] = v
        # if isinstance(model.feature_extractor, EncodecFeatures):
        #     encodec_parameters = {
        #         "feature_extractor.encodec." + key: value
        #         for key, value in model.feature_extractor.encodec.state_dict().items()
        #     }
        #     state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def from_hparams0802(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        # NOTE: this call omits the required multiperioddisc/multiresddisc arguments of
        # __init__ in this file and would raise TypeError as written.
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained0802(cls, config_path, model_path) -> "WavTokenizer":
        """
        Class method to create a new model instance from a pre-trained checkpoint stored on local disk.
        """
        model = cls.from_hparams0802(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        # if isinstance(model.feature_extractor, EncodecFeatures):
        #     encodec_parameters = {
        #         "feature_extractor.encodec." + key: value
        #         for key, value in model.feature_extractor.encodec.state_dict().items()
        #     }
        #     state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features, _, _ = self.feature_extractor(audio_input, **kwargs)  # 0818
        audio_output = self.decode(features, **kwargs)
        return audio_output

    # 0818
    @torch.inference_mode()
    def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        features, _, _ = self.feature_extractor(audio_input, **kwargs)
        return features

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.
        """
        assert isinstance(self.feature_extractor, EncodecFeatures), "Feature extractor should be an instance of EncodecFeatures"

        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = features.transpose(1, 2)
        return features
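Example (not part of the commit): a sketch of the token round trip this variant exposes, assuming an EncodecFeatures extractor with a populated codebook_weights buffer; the paths and the (K, B, L) code tensor are fabricated placeholders.

import torch
from inspiremusic.wavtokenizer.decoder.pretrained_model import WavTokenizer

model = WavTokenizer.from_pretrained0828("config.yaml", "model.ckpt")  # placeholder paths
codes = torch.randint(0, 1024, (1, 2, 75))  # (K, B, L): 1 codebook, batch 2, 75 steps
features = model.codes_to_features(codes)   # (B, C, L) embeddings looked up from the codebook
audio = model.decode(features)              # reconstructed waveform, shape (B, T)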
inspiremusic/wavtokenizer/decoder/spectral_ops.py
0 → 100644
View file @
0112b0f0
import
numpy
as
np
import
scipy
import
torch
from
torch
import
nn
,
view_as_real
,
view_as_complex
import
pdb
class
ISTFT
(
nn
.
Module
):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def
__init__
(
self
,
n_fft
:
int
,
hop_length
:
int
,
win_length
:
int
,
padding
:
str
=
"same"
):
super
().
__init__
()
if
padding
not
in
[
"center"
,
"same"
]:
raise
ValueError
(
"Padding must be 'center' or 'same'."
)
self
.
padding
=
padding
self
.
n_fft
=
n_fft
self
.
hop_length
=
hop_length
self
.
win_length
=
win_length
window
=
torch
.
hann_window
(
win_length
)
self
.
register_buffer
(
"window"
,
window
)
def
forward
(
self
,
spec
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if
self
.
padding
==
"center"
:
# Fallback to pytorch native implementation
return
torch
.
istft
(
spec
,
self
.
n_fft
,
self
.
hop_length
,
self
.
win_length
,
self
.
window
,
center
=
True
)
elif
self
.
padding
==
"same"
:
pad
=
(
self
.
win_length
-
self
.
hop_length
)
//
2
else
:
raise
ValueError
(
"Padding must be 'center' or 'same'."
)
assert
spec
.
dim
()
==
3
,
"Expected a 3D tensor as input"
B
,
N
,
T
=
spec
.
shape
# Inverse FFT
ifft
=
torch
.
fft
.
irfft
(
spec
,
self
.
n_fft
,
dim
=
1
,
norm
=
"backward"
)
ifft
=
ifft
*
self
.
window
[
None
,
:,
None
]
# Overlap and Add
output_size
=
(
T
-
1
)
*
self
.
hop_length
+
self
.
win_length
y
=
torch
.
nn
.
functional
.
fold
(
ifft
,
output_size
=
(
1
,
output_size
),
kernel_size
=
(
1
,
self
.
win_length
),
stride
=
(
1
,
self
.
hop_length
),
)[:,
0
,
0
,
pad
:
-
pad
]
# Window envelope
window_sq
=
self
.
window
.
square
().
expand
(
1
,
T
,
-
1
).
transpose
(
1
,
2
)
window_envelope
=
torch
.
nn
.
functional
.
fold
(
window_sq
,
output_size
=
(
1
,
output_size
),
kernel_size
=
(
1
,
self
.
win_length
),
stride
=
(
1
,
self
.
hop_length
),
).
squeeze
()[
pad
:
-
pad
]
# Normalize
# assert (window_envelope > 1e-11).all()
if
not
torch
.
all
(
window_envelope
>
1e-11
):
window_envelope
=
torch
.
clamp
(
window_envelope
,
min
=
1e-11
)
y
=
y
/
window_envelope
return
y
def
onnx_forward
(
self
,
spec
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if
self
.
padding
==
"center"
:
# Fallback to pytorch native implementation
return
torch
.
istft
(
spec
,
self
.
n_fft
,
self
.
hop_length
,
self
.
win_length
,
self
.
window
,
center
=
True
)
elif
self
.
padding
==
"same"
:
pad
=
(
self
.
win_length
-
self
.
hop_length
)
//
2
else
:
raise
ValueError
(
"Padding must be 'center' or 'same'."
)
assert
spec
.
dim
()
==
3
,
"Expected a 3D tensor as input"
B
,
N
,
T
=
spec
.
shape
pdb
.
set_trace
()
# Inverse FFT
ifft
=
torch
.
fft
.
irfft
(
spec
,
self
.
n_fft
,
dim
=
1
,
norm
=
"backward"
)
ifft
=
ifft
*
self
.
window
[
None
,
:,
None
]
# Overlap and Add
output_size
=
(
T
-
1
)
*
self
.
hop_length
+
self
.
win_length
y
=
torch
.
nn
.
functional
.
fold
(
ifft
,
output_size
=
(
1
,
output_size
),
kernel_size
=
(
1
,
self
.
win_length
),
stride
=
(
1
,
self
.
hop_length
),
)[:,
0
,
0
,
pad
:
-
pad
]
# Window envelope
window_sq
=
self
.
window
.
square
().
expand
(
1
,
T
,
-
1
).
transpose
(
1
,
2
)
window_envelope
=
torch
.
nn
.
functional
.
fold
(
window_sq
,
output_size
=
(
1
,
output_size
),
kernel_size
=
(
1
,
self
.
win_length
),
stride
=
(
1
,
self
.
hop_length
),
).
squeeze
()[
pad
:
-
pad
]
# Normalize
# assert (window_envelope > 1e-11).all()
if
not
torch
.
all
(
window_envelope
>
1e-11
):
window_envelope
=
torch
.
clamp
(
window_envelope
,
min
=
1e-11
)
y
=
y
/
window_envelope
return
y
class
MDCT
(
nn
.
Module
):
"""
Modified Discrete Cosine Transform (MDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def
__init__
(
self
,
frame_len
:
int
,
padding
:
str
=
"same"
):
super
().
__init__
()
if
padding
not
in
[
"center"
,
"same"
]:
raise
ValueError
(
"Padding must be 'center' or 'same'."
)
self
.
padding
=
padding
self
.
frame_len
=
frame_len
N
=
frame_len
//
2
n0
=
(
N
+
1
)
/
2
window
=
torch
.
from_numpy
(
scipy
.
signal
.
cosine
(
frame_len
)).
float
()
self
.
register_buffer
(
"window"
,
window
)
pre_twiddle
=
torch
.
exp
(
-
1j
*
torch
.
pi
*
torch
.
arange
(
frame_len
)
/
frame_len
)
post_twiddle
=
torch
.
exp
(
-
1j
*
torch
.
pi
*
n0
*
(
torch
.
arange
(
N
)
+
0.5
)
/
N
)
# view_as_real: NCCL Backend does not support ComplexFloat data type
# https://github.com/pytorch/pytorch/issues/71613
self
.
register_buffer
(
"pre_twiddle"
,
view_as_real
(
pre_twiddle
))
self
.
register_buffer
(
"post_twiddle"
,
view_as_real
(
post_twiddle
))
def
forward
(
self
,
audio
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
Args:
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
and T is the length of the audio.
Returns:
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
and N is the number of frequency bins.
"""
if
self
.
padding
==
"center"
:
audio
=
torch
.
nn
.
functional
.
pad
(
audio
,
(
self
.
frame_len
//
2
,
self
.
frame_len
//
2
))
elif
self
.
padding
==
"same"
:
# hop_length is 1/2 frame_len
audio
=
torch
.
nn
.
functional
.
pad
(
audio
,
(
self
.
frame_len
//
4
,
self
.
frame_len
//
4
))
else
:
raise
ValueError
(
"Padding must be 'center' or 'same'."
)
x
=
audio
.
unfold
(
-
1
,
self
.
frame_len
,
self
.
frame_len
//
2
)
N
=
self
.
frame_len
//
2
x
=
x
*
self
.
window
.
expand
(
x
.
shape
)
X
=
torch
.
fft
.
fft
(
x
*
view_as_complex
(
self
.
pre_twiddle
).
expand
(
x
.
shape
),
dim
=-
1
)[...,
:
N
]
res
=
X
*
view_as_complex
(
self
.
post_twiddle
).
expand
(
X
.
shape
)
*
np
.
sqrt
(
1
/
N
)
return
torch
.
real
(
res
)
*
np
.
sqrt
(
2
)
class
IMDCT
(
nn
.
Module
):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def
__init__
(
self
,
frame_len
:
int
,
padding
:
str
=
"same"
):
super
().
__init__
()
if
padding
not
in
[
"center"
,
"same"
]:
raise
ValueError
(
"Padding must be 'center' or 'same'."
)
self
.
padding
=
padding
self
.
frame_len
=
frame_len
N
=
frame_len
//
2
n0
=
(
N
+
1
)
/
2
window
=
torch
.
from_numpy
(
scipy
.
signal
.
cosine
(
frame_len
)).
float
()
self
.
register_buffer
(
"window"
,
window
)
pre_twiddle
=
torch
.
exp
(
1j
*
torch
.
pi
*
n0
*
torch
.
arange
(
N
*
2
)
/
N
)
post_twiddle
=
torch
.
exp
(
1j
*
torch
.
pi
*
(
torch
.
arange
(
N
*
2
)
+
n0
)
/
(
N
*
2
))
self
.
register_buffer
(
"pre_twiddle"
,
view_as_real
(
pre_twiddle
))
self
.
register_buffer
(
"post_twiddle"
,
view_as_real
(
post_twiddle
))
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
        y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
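A quick round-trip sketch of the two transforms above (a hedged example, not part of the commit: it assumes the MDCT class defined earlier in this file takes the same (frame_len, padding) arguments as IMDCT, which its mirrored twiddle buffers suggest, and that the module lives at the path shown in this commit's file list):

import torch
from inspiremusic.wavtokenizer.decoder.spectral_ops import MDCT, IMDCT

frame_len = 512
mdct = MDCT(frame_len, padding="same")    # analysis: (B, T) -> (B, L, N)
imdct = IMDCT(frame_len, padding="same")  # synthesis: (B, L, N) -> (B, T)

audio = torch.randn(2, 16_384)            # T chosen as a multiple of frame_len
coeffs = mdct(audio)                      # (2, 64, 256): L = T // 256, N = 256
recon = imdct(coeffs)                     # (2, 16_384)

# The cosine window with 50% overlap satisfies time-domain alias cancellation,
# so the interior of the signal should reconstruct up to numerical error;
# the outermost half-frames lack an overlap partner and are not compared here.
inner = slice(frame_len, -frame_len)
print(torch.allclose(audio[:, inner], recon[:, inner], atol=1e-4))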
inspiremusic/wavtokenizer/encoder/__init__.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa

"""EnCodec neural audio codec."""

__version__ = "0.1.2a3"

from .model import EncodecModel
inspiremusic/wavtokenizer/encoder/__pycache__/__init__.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/__pycache__/distrib.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/__pycache__/model.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/__pycache__/utils.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/distrib.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all-reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all workers have the same number of params, this inequality
        # holds for at least one of them.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
def sync_buffer(buffers, average=True):
    """
    Sync buffers across workers. If `average` is False, broadcast from rank 0
    instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()
def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel that doesn't rely on any black magic.
    For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
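A minimal sketch of how these helpers compose into a hand-rolled data-parallel loop (hypothetical usage, not part of this commit; the model, optimizer, and batch are placeholders):

import torch
from inspiremusic.wavtokenizer.encoder import distrib

model = torch.nn.Linear(16, 1)                 # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

# 1. Start every worker from rank 0's weights.
distrib.broadcast_tensors(model.parameters())

for step in range(10):
    x = torch.randn(8, 16)                     # placeholder per-worker batch
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    distrib.sync_grad(model.parameters())      # 2. Average gradients across workers.
    opt.step()

# 3. Average logged metrics, weighting by `count` when batch sizes differ.
metrics = distrib.average_metrics({'loss': loss.item()}, count=x.shape[0])

Each helper is a no-op in a single-process run, so the same loop works with or without torch.distributed initialized.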
inspiremusic/wavtokenizer/encoder/model.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""EnCodec model implementation."""

import math
from pathlib import Path
import typing as tp

import numpy as np
import torch
from torch import nn

from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url


ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'

EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
class LMModel(nn.Module):
    """Language Model to estimate probabilities of each codebook entry.
    We predict all codebooks in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
    """
    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card = card
        self.n_q = n_q
        self.dim = dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """
        Args:
            indices (torch.Tensor): indices from the previous time step. Indices
                should be 1 + the actual index in the codebook. The value 0 is reserved
                for when the index is missing (i.e. the first time step). Shape should
                be `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns a 3-tuple `(probabilities, new_states, new_offset)`, with probabilities
        of shape `[B, card, n_q, T]`.
        """
        B, K, T = indices.shape
        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
        out, states, offset = self.transformer(input_, states, offset)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
        return torch.softmax(logits, dim=1), states, offset
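# Usage sketch (illustrative only, not exercised in this file): streaming sampling
# feeds back the previous step's codes shifted by +1, with 0 meaning "no code yet".
# Here `model` stands for an EncodecModel instance and B for a batch size:
#
#   lm = model.get_lm_model()
#   indices = torch.zeros(B, lm.n_q, 1, dtype=torch.long)  # first step: all missing
#   probs, states, offset = lm(indices)                    # probs: [B, card, n_q, 1]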
class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.

    Args:
        target_bandwidths (list of float): Target bandwidths.
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segments, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """
    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."
    @property
    def segment_length(self) -> tp.Optional[int]:
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))
    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames
    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        length = x.shape[-1]
        duration = length / self.sample_rate
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale
    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.
        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)
    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        codes, scale = encoded_frame
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frames = self.encode(x)
        return self.decode(frames)[:, :, :x.shape[-1]]
    def set_target_bandwidth(self, bandwidth: float):
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth
    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate."""
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            raise RuntimeError("No LM pre-trained for the current Encodec model.")
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm
    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model
    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            return torch.load(file)
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type: ignore
    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model."""
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6., 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model
    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model."""
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model
def test():
    from itertools import product
    import torchaudio
    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz
    }
    for model_name, bw in product(models.keys(), bandwidths):
        model = models[model_name]()
        model.set_target_bandwidth(bw)
        audio_suffix = model_name.split('_')[1][:3]
        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
        wav = wav[:, :model.sample_rate * 2]
        wav_in = wav.unsqueeze(0)
        wav_dec = model(wav_in)[0]
        assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()
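An end-to-end usage sketch based only on the methods defined above (the 24 kHz factory downloads Meta's pretrained weights from ROOT_URL; pass pretrained=False to skip the download):

import torch
from inspiremusic.wavtokenizer.encoder.model import EncodecModel

model = EncodecModel.encodec_model_24khz()     # causal mono model
model.set_target_bandwidth(6.0)                # must be one of model.target_bandwidths

wav = torch.randn(1, 1, 24_000)                # (batch, channels, samples): 1 s of audio
frames = model.encode(wav)                     # list of (codes, scale) tuples
codes, scale = frames[0]                       # codes: [B, K, T] discrete indices

decoded = model.decode(frames)[:, :, :wav.shape[-1]]  # trim any extra trailing samples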
inspiremusic/wavtokenizer/encoder/modules/__init__.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformerEncoder
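For orientation, a small sketch of the SEANet pair exported here, using only constructor arguments and attributes that model.py above already relies on (shapes are indicative, not guaranteed, since the defaults of these modules are defined elsewhere in the commit):

import torch
from inspiremusic.wavtokenizer.encoder import modules as m

encoder = m.SEANetEncoder(channels=1, norm='weight_norm', causal=True)
decoder = m.SEANetDecoder(channels=1, norm='weight_norm', causal=True)

wav = torch.randn(1, 1, 24_000)   # (batch, channels, samples)
emb = encoder(wav)                # (batch, encoder.dimension, ~samples / encoder.hop_length)
out = decoder(emb)                # back to (batch, channels, ~samples)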
inspiremusic/wavtokenizer/encoder/modules/__pycache__/__init__.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/modules/__pycache__/conv.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/modules/__pycache__/lstm.cpython-310.pyc
0 → 100644
File added
inspiremusic/wavtokenizer/encoder/modules/__pycache__/norm.cpython-310.pyc
0 → 100644
File added