ModelZoo / InspireMusic_pytorch / Commits / 0112b0f0

Commit 0112b0f0 authored Feb 14, 2025 by chenzk
v1.0
Pipeline #2394 canceled with stages
Changes 474 · Pipelines 1
Showing 20 changed files with 2783 additions and 0 deletions (+2783 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/__init__.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/feature_extractors.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/heads.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/models.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/modules.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/pretrained.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/spectral_ops.cpython-310.pyc (+0 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/dataset.py (+125 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/discriminator_dac.py (+249 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/discriminators.py (+202 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/experiment.py (+474 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/feature_extractors.py (+177 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/heads.py (+159 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/helpers.py (+71 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/loss.py (+159 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/models.py (+266 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/modules.py (+214 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/pretrained.py (+253 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/pretrained_model.py (+192 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/spectral_ops.py (+242 / -0)
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/__init__.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/feature_extractors.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/heads.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/models.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/modules.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/pretrained.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/__pycache__/spectral_ops.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/decoder/dataset.py  0 → 100644
from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
import soundfile
# import librosa
import random

torch.set_num_threads(1)


@dataclass
class DataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    return torch.stack(batch, dim=0)


class VocosDataModule(LightningDataModule):
    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloder(self, cfg: DataConfig, train: bool):
        dataset = VocosDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
            shuffle=train, pin_memory=True, collate_fn=collate_fn
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.val_config, train=False)


class VocosDataset(Dataset):
    def __init__(self, cfg: DataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        # y, sr = torchaudio.load(audio_path)
        # print(audio_path, "111")
        try:
            y1, sr = soundfile.read(audio_path)
            # y1, sr = librosa.load(audio_path, sr=None)
            y = torch.tensor(y1).float().unsqueeze(0)
            # if y.size(0) > 1:
            #     # mix to mono
            #     y = y.mean(dim=0, keepdim=True)
            if y.ndim > 2:
                # mix to mono
                # print("something is wrong here, in the data processing part")
                # y = y.mean(dim=-1, keepdim=False)
                random_channel = random.randint(0, y.size(-1) - 1)
                y = y[:, :, random_channel]
            gain = np.random.uniform(-1, -6) if self.train else -3
            y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
            if sr != self.sampling_rate:
                y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
            if y.size(-1) < self.num_samples:
                pad_length = self.num_samples - y.size(-1)
                padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
                y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
            elif self.train:
                start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
                y = y[:, start: start + self.num_samples]
            else:
                # During validation, take always the first segment for determinism
                y = y[:, :self.num_samples]
            return y[0]
        except Exception as e:
            print(f"Error processing file {audio_path} at index {index}: {e}")
            # Either re-raise the exception here, or return None to mark the sample as invalid
            return None

    # def __getitem__(self, index: int) -> torch.Tensor:
    #     audio_path = self.filelist[index]
    #     try:
    #         y, sr = torchaudio.load(audio_path)
    #         if y.size(0) > 1:
    #             # randomly pick one channel
    #             random_channel = random.randint(0, y.size(0) - 1)
    #             y = y[random_channel, :].unsqueeze(0)  # keep the return value in (1, T) form
    #         # gain = np.random.uniform(-1, -6) if self.train else -3
    #         # y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
    #         if sr != self.sampling_rate:
    #             y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
    #         if y.size(-1) < self.num_samples:
    #             pad_length = self.num_samples - y.size(-1)
    #             padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
    #             y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
    #         elif self.train:
    #             start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
    #             y = y[:, start: start + self.num_samples]
    #         else:
    #             # During validation, take always the first segment for determinism
    #             y = y[:, :self.num_samples]
    #         return y[0]
    #     except Exception as e:
    #         print(f"Error processing file {audio_path} at index {index}: {e}")
    #         # Either re-raise the exception here, or return None to mark the sample as invalid
    #         return None
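For reference, a minimal usage sketch of the data module added above (not part of this commit; the filelist names, batch size, and clip length are placeholder assumptions, and the import path assumes the wavtokenizer directory is on PYTHONPATH):

# Hypothetical usage sketch for DataConfig / VocosDataModule -- not part of the commit.
# "train_files.txt" and "val_files.txt" are placeholder filelists with one audio path per line.
from decoder.dataset import DataConfig, VocosDataModule

train_cfg = DataConfig(filelist_path="train_files.txt", sampling_rate=24000,
                       num_samples=24000, batch_size=16, num_workers=4)
val_cfg = DataConfig(filelist_path="val_files.txt", sampling_rate=24000,
                     num_samples=24000, batch_size=16, num_workers=4)
dm = VocosDataModule(train_params=train_cfg, val_params=val_cfg)
batch = next(iter(dm.train_dataloader()))  # (batch_size, num_samples); None items are dropped by collate_fn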
examples/music_generation/inspiremusic/wavtokenizer/decoder/discriminator_dac.py  0 → 100644
import torch
import torch.nn as nn
import torch.nn.functional as F

# from audiotools import AudioSignal
# from audiotools import ml
# from audiotools import STFTParams
from einops import rearrange
from torch.nn.utils import weight_norm
from collections import namedtuple

STFTParams = namedtuple(
    "STFTParams",
    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
)

STFTParams.__new__.__defaults__ = (None, None, None, None, None)


def WNConv1d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv1d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


def WNConv2d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


class MPD(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False)

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
        return x

    def forward(self, x):
        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for layer in self.convs:
            x = layer(x)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return fmap


class MSD(nn.Module):
    def __init__(self, rate: int = 1, sample_rate: int = 48000):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                WNConv1d(1, 16, 15, 1, padding=7),
                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
                WNConv1d(1024, 1024, 5, 1, padding=2),
            ]
        )
        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
        self.sample_rate = sample_rate
        self.rate = rate

    def forward(self, x):
        # x = AudioSignal(x, self.sample_rate)
        # x.resample(self.sample_rate // self.rate)
        # x = x.audio_data

        fmap = []

        for l in self.convs:
            x = l(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]


class MRD(nn.Module):
    def __init__(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 24000,
        bands: list = BANDS,
    ):
        """Complex multi-band spectrogram discriminator.
        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_factor : float, optional
            Hop factor of the STFT, defaults to ``0.25 * window_length``.
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 24000
        bands : list, optional
            Bands to run discriminator over.
        """
        super().__init__()

        self.window_length = window_length
        self.hop_factor = hop_factor
        self.sample_rate = sample_rate
        self.stft_params = STFTParams(
            window_length=window_length,
            hop_length=int(window_length * hop_factor),
            match_stride=True,
        )

        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        self.n_fft = window_length

        ch = 32
        convs = lambda: nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def spectrogram(self, x):
        # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
        # x = torch.view_as_real(x.stft())
        # x.squeeze(0).stft(n_fft=1024, win_length=1024, return_complex=True).size()
        # breakpoint()
        if x.size(0) == 1:
            # x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length, return_complex=True).unsqueeze(0))
            x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft, return_complex=True).unsqueeze(0))
        else:
            # x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length, return_complex=True).unsqueeze(1))
            x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft, return_complex=True).unsqueeze(1))
        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
        # Split into bands
        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
        return x_bands

    def forward(self, x):
        x_bands = self.spectrogram(x)
        fmap = []

        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for layer in stack:
                band = layer(band)
                fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


# class DACDiscriminator(ml.BaseModel):
class DACDiscriminator(nn.Module):
    def __init__(
        self,
        rates: list = [],
        periods: list = [2, 3, 5, 7, 11],
        fft_sizes: list = [2048, 1024, 512],
        sample_rate: int = 24000,
        bands: list = BANDS,
    ):
        """Discriminator that combines multiple discriminators.

        Parameters
        ----------
        rates : list, optional
            sampling rates (in Hz) to run MSD at, by default []
            If empty, MSD is not used.
        periods : list, optional
            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
        fft_sizes : list, optional
            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 24000
        bands : list, optional
            Bands to run MRD at, by default `BANDS`
        """
        super().__init__()
        discs = []
        discs += [MPD(p) for p in periods]
        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
        self.discriminators = nn.ModuleList(discs)

    def preprocess(self, y):
        # Remove DC offset
        y = y - y.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        return y

    def forward(self, x):
        x = self.preprocess(x)
        fmaps = [d(x) for d in self.discriminators]
        return fmaps


if __name__ == "__main__":
    disc = DACDiscriminator()
    x = torch.zeros(1, 1, 24000)
    results = disc(x)
    breakpoint()
    for i, result in enumerate(results):
        print(f"disc{i}")
        for i, r in enumerate(result):
            print(r.shape, r.mean(), r.min(), r.max())
        print("00")
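A standalone sketch (not from the commit) of what DACDiscriminator.preprocess does before any sub-discriminator runs: remove the DC offset along time, then peak-normalize to 0.8.

# Illustrative re-statement of DACDiscriminator.preprocess on random audio.
import torch

y = torch.randn(2, 1, 24000) * 3 + 0.5                        # (B, C, T) waveform with a DC offset
y = y - y.mean(dim=-1, keepdim=True)                          # remove DC offset per (batch, channel)
y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)   # peak-normalize to 0.8
print(y.mean(dim=-1).abs().max(), y.abs().max())              # mean ~0, peak ~0.8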
examples/music_generation/inspiremusic/wavtokenizer/decoder/discriminators.py  0 → 100644
from typing import Tuple, List

import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm


class MultiPeriodDiscriminator(nn.Module):
    """
    Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
    Additionally, it allows incorporating conditional information with a learned embeddings table.

    Args:
        periods (tuple[int]): Tuple of periods for each discriminator.
        num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
            Defaults to None.
    """

    def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorP(nn.Module):
    def __init__(
        self,
        period: int,
        in_channels: int = 1,
        kernel_size: int = 5,
        stride: int = 3,
        lrelu_slope: float = 0.1,
        num_embeddings: int = None,
    ):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
            ]
        )
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu_slope = lrelu_slope

    def forward(
        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x = x.unsqueeze(1)
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:
            # pad first
            n_pad = self.period - (t % self.period)
            x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for i, l in enumerate(self.convs):
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
            if i > 0:
                fmap.append(x)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
        num_embeddings: int = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator.
                Each resolution should be a tuple of (n_fft, hop_length, win_length).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorR(nn.Module):
    def __init__(
        self,
        resolution: Tuple[int, int, int],
        channels: int = 64,
        in_channels: int = 1,
        num_embeddings: int = None,
        lrelu_slope: float = 0.1,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.lrelu_slope = lrelu_slope
        self.convs = nn.ModuleList(
            [
                weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
            ]
        )
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)
        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))

    def forward(
        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        fmap = []
        x = self.spectrogram(x)
        x = x.unsqueeze(1)
        for l in self.convs:
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
            fmap.append(x)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h
        x = torch.flatten(x, 1, -1)

        return x, fmap

    def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        n_fft, hop_length, win_length = self.resolution
        magnitude_spectrogram = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=None,  # interestingly rectangular window kind of works here
            center=True,
            return_complex=True,
        ).abs()

        return magnitude_spectrogram
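A quick shape check (illustrative, not part of the commit) for the conditional discriminators above; bandwidth_id indexes the learned embedding table when num_embeddings is set. The import path assumes the wavtokenizer directory is on PYTHONPATH.

# Hypothetical smoke test for MultiPeriodDiscriminator with conditional embeddings.
import torch
from decoder.discriminators import MultiPeriodDiscriminator

y = torch.randn(2, 24000)         # real waveforms, (B, T)
y_hat = torch.randn(2, 24000)     # generated waveforms, (B, T)
bandwidth_id = torch.tensor([0])  # which conditional embedding to use

mpd = MultiPeriodDiscriminator(num_embeddings=4)
scores_real, scores_gen, fmaps_real, fmaps_gen = mpd(y=y, y_hat=y_hat, bandwidth_id=bandwidth_id)
print(len(scores_real), scores_real[0].shape)  # one flattened score tensor per period (5 by default)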
examples/music_generation/inspiremusic/wavtokenizer/decoder/experiment.py  0 → 100644
import math

import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
import transformers
import yaml

from decoder.discriminator_dac import DACDiscriminator
from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from decoder.feature_extractors import FeatureExtractor
from decoder.heads import FourierHead
from decoder.helpers import plot_spectrogram_to_numpy
from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, DACGANLoss
from decoder.models import Backbone
from decoder.modules import safe_log
from decoder.pretrained_model import instantiate_class


class VocosExp(pl.LightningModule):
    # noinspection PyUnusedLocal
    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
        resume_config: str,
        resume_model: str,
        sample_rate: int = 24000,
        initial_learning_rate: float = 2e-4,
        num_warmup_steps: int = 0,
        mel_loss_coeff: float = 45,
        mrd_loss_coeff: float = 1.0,
        pretrain_mel_steps: int = 0,
        decay_mel_coeff: bool = False,
        evaluate_utmos: bool = False,
        evaluate_pesq: bool = False,
        evaluate_periodicty: bool = False,
        resume: bool = False,
    ):
        """
        Args:
            feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
            backbone (Backbone): An instance of Backbone model.
            head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
            sample_rate (int): Sampling rate of the audio signals.
            initial_learning_rate (float): Initial learning rate for the optimizer.
            num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
            mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
            mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
            pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
            decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
            evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
            evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
            evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])

        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head
        self.resume_config = resume_config
        self.resume_model = resume_model
        self.resume = resume

        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiresddisc = MultiResolutionDiscriminator()

        self.dac = DACDiscriminator()
        self.dacdiscriminator = DACGANLoss(self.dac)

        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)

        self.train_discriminator = False
        self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff

    def configure_optimizers(self):
        disc_params = [
            {"params": self.multiperioddisc.parameters()},
            {"params": self.multiresddisc.parameters()},
            {"params": self.dac.parameters()},
        ]
        gen_params = [
            {"params": self.feature_extractor.parameters()},
            {"params": self.backbone.parameters()},
            {"params": self.head.parameters()},
        ]

        opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
        opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)

        max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
        scheduler_disc = transformers.get_cosine_schedule_with_warmup(
            opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
        )
        scheduler_gen = transformers.get_cosine_schedule_with_warmup(
            opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
        )

        return (
            [opt_disc, opt_gen],
            [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
        )

    def forward(self, audio_input, **kwargs):
        features, _, commit_loss = self.feature_extractor(audio_input, **kwargs)
        # print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g'])
        x = self.backbone(features, **kwargs)
        audio_output = self.head(x)
        return audio_output, commit_loss

    def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
        audio_input = batch
        # train discriminator
        if optimizer_idx == 0 and self.train_discriminator:
            with torch.no_grad():
                audio_hat, _ = self(audio_input, **kwargs)

            loss_dac = self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))

            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
            real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
            loss_mp, loss_mp_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
            )
            loss_mrd, loss_mrd_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
            )
            loss_mp /= len(loss_mp_real)
            loss_mrd /= len(loss_mrd_real)
            loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac

            self.log("discriminator/total", loss, prog_bar=True)
            self.log("discriminator/multi_period_loss", loss_mp)
            self.log("discriminator/multi_res_loss", loss_mrd)
            self.log("discriminator/dac", loss_dac)
            return loss

        # train generator
        if optimizer_idx == 1:
            audio_hat, commit_loss = self(audio_input, **kwargs)
            if self.train_discriminator:
                loss_dac_1, loss_dac_2 = self.dacdiscriminator.generator_loss(
                    audio_hat.unsqueeze(1), audio_input.unsqueeze(1)
                )
                _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
                    y=audio_input, y_hat=audio_hat, **kwargs,
                )
                _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
                    y=audio_input, y_hat=audio_hat, **kwargs,
                )
                loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
                loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
                loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
                loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
                loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
                loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)

                self.log("generator/multi_period_loss", loss_gen_mp)
                self.log("generator/multi_res_loss", loss_gen_mrd)
                self.log("generator/feature_matching_mp", loss_fm_mp)
                self.log("generator/feature_matching_mrd", loss_fm_mrd)
                self.log("generator/loss_dac_1", loss_dac_1)
                self.log("generator/loss_dac_2", loss_dac_2)
            else:
                # also zero the DAC generator losses so the total below stays defined
                loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = loss_dac_1 = loss_dac_2 = 0

            mel_loss = self.melspec_loss(audio_hat, audio_input)
            loss = (
                loss_gen_mp
                + self.hparams.mrd_loss_coeff * loss_gen_mrd
                + loss_fm_mp
                + self.hparams.mrd_loss_coeff * loss_fm_mrd
                + self.mel_loss_coeff * mel_loss
                + 1000 * commit_loss
                + loss_dac_1
                + loss_dac_2
            )

            self.log("generator/total_loss", loss, prog_bar=True)
            self.log("mel_loss_coeff", self.mel_loss_coeff)
            self.log("generator/mel_loss", mel_loss)
            self.log("commit_loss", commit_loss)

            if self.global_step % 1000 == 0 and self.global_rank == 0:
                self.logger.experiment.add_audio(
                    "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
                )
                self.logger.experiment.add_audio(
                    "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
                )
                with torch.no_grad():
                    mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
                    mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
                self.logger.experiment.add_image(
                    "train/mel_target",
                    plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
                    self.global_step,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    "train/mel_pred",
                    plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
                    self.global_step,
                    dataformats="HWC",
                )

            return loss

    def on_validation_epoch_start(self):
        if self.hparams.evaluate_utmos:
            from metrics.UTMOS import UTMOSScore

            if not hasattr(self, "utmos_model"):
                self.utmos_model = UTMOSScore(device=self.device)

    def validation_step(self, batch, batch_idx, **kwargs):
        audio_input = batch
        audio_hat, commit_loss = self(audio_input, **kwargs)

        audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
        audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)

        if self.hparams.evaluate_periodicty:
            from metrics.periodicity import calculate_periodicity_metrics

            periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
        else:
            periodicity_loss = pitch_loss = f1_score = 0

        if self.hparams.evaluate_utmos:
            utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
        else:
            utmos_score = torch.zeros(1, device=self.device)

        if self.hparams.evaluate_pesq:
            from pesq import pesq

            pesq_score = 0
            for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
                pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
            pesq_score /= len(audio_16_khz)
            pesq_score = torch.tensor(pesq_score)
        else:
            pesq_score = torch.zeros(1, device=self.device)

        mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
        total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss

        return {
            "val_loss": total_loss,
            "mel_loss": mel_loss,
            "utmos_score": utmos_score,
            "pesq_score": pesq_score,
            "periodicity_loss": periodicity_loss,
            "pitch_loss": pitch_loss,
            "f1_score": f1_score,
            "audio_input": audio_input[0],
            "audio_pred": audio_hat[0],
        }

    def validation_epoch_end(self, outputs):
        if self.global_rank == 0:
            *_, audio_in, audio_pred = outputs[0].values()
            self.logger.experiment.add_audio(
                "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
            )
            self.logger.experiment.add_audio(
                "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
            )
            mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
            mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
            self.logger.experiment.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            self.logger.experiment.add_image(
                "val_mel_hat",
                plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
        utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
        pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
        periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
        pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
        f1_score = np.array([x["f1_score"] for x in outputs]).mean()

        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val/mel_loss", mel_loss, sync_dist=True)
        self.log("val/utmos_score", utmos_score, sync_dist=True)
        self.log("val/pesq_score", pesq_score, sync_dist=True)
        self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
        self.log("val/pitch_loss", pitch_loss, sync_dist=True)
        self.log("val/f1_score", f1_score, sync_dist=True)

    @property
    def global_step(self):
        """
        Override global_step so that it returns the total number of batches processed
        """
        return self.trainer.fit_loop.epoch_loop.total_batch_idx

    def on_train_batch_start(self, *args):
        if self.global_step >= self.hparams.pretrain_mel_steps:
            self.train_discriminator = True
        else:
            self.train_discriminator = False

    def on_train_batch_end(self, *args):
        def mel_loss_coeff_decay(current_step, num_cycles=0.5):
            max_steps = self.trainer.max_steps // 2
            if current_step < self.hparams.num_warmup_steps:
                return 1.0
            progress = float(current_step - self.hparams.num_warmup_steps) / float(
                max(1, max_steps - self.hparams.num_warmup_steps)
            )
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

        if self.hparams.decay_mel_coeff:
            self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)


class WavTokenizer(VocosExp):
    """
    WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
    It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
    a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
    while during validation, a fixed bandwidth_id is used.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
        resume_config: str,
        resume_model: str,
        sample_rate: int = 24000,
        initial_learning_rate: float = 2e-4,
        num_warmup_steps: int = 0,
        mel_loss_coeff: float = 45,
        mrd_loss_coeff: float = 1.0,
        pretrain_mel_steps: int = 0,
        decay_mel_coeff: bool = False,
        evaluate_utmos: bool = False,
        evaluate_pesq: bool = False,
        evaluate_periodicty: bool = False,
        resume: bool = False,
    ):
        super().__init__(
            feature_extractor,
            backbone,
            head,
            resume_config,
            resume_model,
            sample_rate,
            initial_learning_rate,
            num_warmup_steps,
            mel_loss_coeff,
            mrd_loss_coeff,
            pretrain_mel_steps,
            decay_mel_coeff,
            evaluate_utmos,
            evaluate_pesq,
            evaluate_periodicty,
            resume,
        )
        # Override with conditional discriminators
        # VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
        # if self.resume:
        #     VocosExp.load_from_checkpoint(self.resume_model)
        self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
        self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
        self.dac = DACDiscriminator()

        if self.resume:
            print('Loading pretrained model:', self.resume_model)
            # with open(self.resume_config, "r") as f:
            #     config = yaml.safe_load(f)
            # feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
            # backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
            # head = instantiate_class(args=(), init=config['model']['init_args']["head"])
            # do not load the quantizer part of the weights
            state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
            state_dict_fa_qa = dict()
            state_dict_fa_en = dict()
            state_dict_fa_de = dict()
            state_dict_bb = dict()
            state_dict_hd = dict()
            state_dict_mp = dict()
            state_dict_mr = dict()
            state_dict_dac = dict()
            for k, v in state_dict_raw.items():
                # breakpoint()
                if k.startswith('feature_extractor.encodec.quantizer'):
                    # breakpoint()
                    # print("*****", k)
                    ss = k[46:48]
                    if ss[-1] == '.':
                        num = int(ss[0])
                        # print("num,k", num, k[36:])
                        if num <= 7:
                            state_dict_fa_qa[k[36:]] = v
                if k.startswith('feature_extractor.encodec.encoder'):
                    state_dict_fa_en[k[34:]] = v
                if k.startswith('feature_extractor.encodec.decoder'):
                    state_dict_fa_de[k[34:]] = v
                if k.startswith('backbone.'):
                    state_dict_bb[k[9:]] = v
                if k.startswith('head.'):
                    state_dict_hd[k[5:]] = v
                if k.startswith('multiperioddisc.'):
                    state_dict_mp[k[16:]] = v
                if k.startswith('multiresddisc.'):
                    state_dict_mr[k[14:]] = v
                if k.startswith('dac.'):
                    state_dict_dac[k[4:]] = v
            # breakpoint()
            # feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
            feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
            feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
            feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
            backbone.load_state_dict(state_dict_bb, strict=True)
            head.load_state_dict(state_dict_hd, strict=True)
            self.feature_extractor = feature_extractor.to(self.device)
            self.backbone = backbone.to(self.device)
            self.head = head.to(self.device)
            self.multiperioddisc.load_state_dict(state_dict_mp, strict=True)
            self.multiresddisc.load_state_dict(state_dict_mr, strict=True)
            self.dac.load_state_dict(state_dict_dac, strict=True)

    def training_step(self, *args):
        # print('-------------------train--------------------')
        # if self.global_rank == 0 and self.resume:
        #     config_path = self.resume_config
        #     model_path = self.resume_model
        #     self.pretrained_load(config_path, model_path)
        #     print('Loading pretrained model:', model_path)
        bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
        output = super().training_step(*args, bandwidth_id=bandwidth_id)
        return output

    def validation_step(self, *args):
        # print('-------------------valid--------------------')
        bandwidth_id = torch.tensor([0], device=self.device)
        output = super().validation_step(*args, bandwidth_id=bandwidth_id)
        return output

    def validation_epoch_end(self, outputs):
        if self.global_rank == 0:
            *_, audio_in, _ = outputs[0].values()
            # Resynthesis with encodec for reference
            self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
            encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
            self.logger.experiment.add_audio(
                "encodec",
                encodec_audio[0, 0].data.cpu().numpy(),
                self.global_step,
                self.hparams.sample_rate,
            )

        super().validation_epoch_end(outputs)
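The mel-loss coefficient decay used in on_train_batch_end above is a plain cosine schedule; a standalone restatement (not part of the commit; max_steps here is an arbitrary example value) to see its shape:

# Standalone sketch of mel_loss_coeff_decay with explicit arguments instead of trainer/hparams state.
import math

def mel_loss_coeff_decay(current_step, max_steps=1_000_000, num_warmup_steps=0, num_cycles=0.5):
    if current_step < num_warmup_steps:
        return 1.0
    progress = float(current_step - num_warmup_steps) / float(max(1, max_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

base_mel_coeff = 45
print(base_mel_coeff * mel_loss_coeff_decay(0))          # 45.0 at the start of training
print(base_mel_coeff * mel_loss_coeff_decay(1_000_000))  # decays to ~0 at max_steps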
examples/music_generation/inspiremusic/wavtokenizer/decoder/feature_extractors.py  0 → 100644
from typing import List

import torch
import torchaudio
from torch import nn
import math

# from inspiremusic.wavtokenizer.decoder.modules import safe_log
from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder
from inspiremusic.wavtokenizer.encoder import EncodecModel
from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)


class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                    C denotes output features, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=padding == "center",
            power=1,
        )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        features = safe_log(mel)
        return features


class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
        num_quantizers: int = 1,
        dowmsamples: List[int] = [6, 5, 5, 4],
        vq_bins: int = 16384,
        vq_kmeans: int = 800,
    ):
        super().__init__()
        # breakpoint()
        self.frame_rate = 25  # not use
        # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
        n_q = num_quantizers  # important
        encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7,
                                dilation_base=2, true_skip=False, compress=2)
        decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7,
                                dilation_base=2, true_skip=False, compress=2)
        quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
                                            decay=0.99, kmeans_init=True)
        # breakpoint()
        if encodec_model == "encodec_24khz":
            self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
                                        target_bandwidths=bandwidths, sample_rate=24000, channels=1)
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
            )
        for param in self.encodec.parameters():
            param.requires_grad = True

        # self.num_q = n_q
        # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    # @torch.no_grad()
    # def get_encodec_codes(self, audio):
    #     audio = audio.unsqueeze(1)
    #     emb = self.encodec.encoder(audio)
    #     codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
    #     return codes

    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)  # audio(16,24000)
        # breakpoint()
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # codes(8,16,75),features(16,128,75)

        return quantized, codes, commit_loss

        # codes = self.get_encodec_codes(audio)
        # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
        # # with offsets given by the number of bins, and finally summed in a vectorized operation.
        # offsets = torch.arange(
        #     0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        # )
        # embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        # return features.transpose(1, 2)

    def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)  # audio(16,24000)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # codes(8,16,75),features(16,128,75)

        return quantized, codes, commit_loss

    def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)  # audio(16,24000)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # codes(8,16,75),features(16,128,75)

        return quantized, codes, commit_loss
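The symlog/symexp helpers defined above are exact inverses of each other; a small standalone check (not part of the commit):

# Illustrative round-trip check for the symlog/symexp pair used in this file.
import torch

def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())

def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)

x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
print(torch.allclose(symexp(symlog(x)), x, atol=1e-5))  # True: symexp inverts symlog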
examples/music_generation/inspiremusic/wavtokenizer/decoder/heads.py  0 → 100644
import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)


class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # wrapping happens here. These two lines produce real and imaginary value
        x = torch.cos(p)
        y = torch.sin(p)
        # recalculating phase here does not produce anything new
        # only costs time
        # phase = torch.atan2(y, x)
        # S = mag * torch.exp(phase * 1j)
        # better directly produce the complex value
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio


class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
                                     based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally init the last layer following mel-scale
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(x, min=-1.0, max=1.0)

        return audio


class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(x, min=-1.0, max=1.0)
        return audio
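ISTFTHead.forward builds the complex spectrogram directly from magnitude and phase. An illustrative check (not from the commit) that mag * (cos(p) + j sin(p)) equals the polar form mag * exp(j p), which is why the commented-out atan2 round-trip above is unnecessary:

# Numerical equivalence of the two ways of forming the complex STFT coefficients.
import torch

mag = torch.rand(4, 8)
p = torch.rand(4, 8) * 6.28
S_direct = mag * (torch.cos(p) + 1j * torch.sin(p))  # what ISTFTHead computes
S_polar = mag * torch.exp(1j * p)                    # equivalent polar form
print(torch.allclose(S_direct, S_polar, atol=1e-6))  # True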
examples/music_generation/inspiremusic/wavtokenizer/decoder/helpers.py  0 → 100644
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback

matplotlib.use("Agg")


def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
    """
    Save a matplotlib figure to a numpy array.

    Args:
        fig (Figure): Matplotlib figure object.

    Returns:
        ndarray: Numpy array representing the figure.
    """
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
    """
    Plot a spectrogram and convert it to a numpy array.

    Args:
        spectrogram (ndarray): Spectrogram data.

    Returns:
        ndarray: Numpy array representing the plotted spectrogram.
    """
    spectrogram = spectrogram.astype(np.float32)
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


class GradNormCallback(Callback):
    """
    Callback to log the gradient norm.
    """

    def on_after_backward(self, trainer, model):
        model.log("grad_norm", gradient_norm(model))


def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the gradient norm.

    Args:
        model (Module): PyTorch model.
        norm_type (float, optional): Type of the norm. Defaults to 2.0.

    Returns:
        Tensor: Gradient norm.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
    return total_norm
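gradient_norm collects every parameter gradient and takes the norm of their norms; a tiny standalone example (not part of the commit) of the same computation after one backward pass:

# Equivalent of gradient_norm(model) on a toy linear model.
import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2.0) for g in grads]), 2.0)
print(total_norm)  # the value GradNormCallback would log as "grad_norm"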
examples/music_generation/inspiremusic/wavtokenizer/decoder/loss.py  0 → 100644
from typing import List, Tuple

import torch
import torchaudio
from torch import nn

from decoder.modules import safe_log
import torch.nn.functional as F


class MelSpecReconstructionLoss(nn.Module):
    """
    L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
    ):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, y_hat, y) -> torch.Tensor:
        """
        Args:
            y_hat (Tensor): Predicted audio waveform.
            y (Tensor): Ground truth audio waveform.

        Returns:
            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
        """
        mel_hat = safe_log(self.mel_spec(y_hat))
        mel = safe_log(self.mel_spec(y))
        loss = torch.nn.functional.l1_loss(mel, mel_hat)
        return loss
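
Since the mel transform uses power=1 and is wrapped in safe_log, the loss compares log-magnitude mel spectrograms. A small sanity-check sketch (random waveforms, illustrative values only):

import torch

mel_loss = MelSpecReconstructionLoss(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100)
y = torch.randn(2, 24000)               # ground-truth waveforms, (B, T)
y_hat = y + 0.01 * torch.randn_like(y)  # slightly perturbed "predictions"
print(mel_loss(y_hat, y))               # scalar L1 distance between log-mel spectrograms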
class GeneratorLoss(nn.Module):
    """
    Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
    """

    def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            disc_outputs (List[Tensor]): List of discriminator outputs.

        Returns:
            Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                                         the sub-discriminators
        """
        loss = 0
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean(torch.clamp(1 - dg, min=0))
            gen_losses.append(l)
            loss += l
        return loss, gen_losses
class DiscriminatorLoss(nn.Module):
    """
    Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
    """

    def forward(
        self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
            disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.

        Returns:
            Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
                                                       the sub-discriminators for real outputs, and a list of
                                                       loss values for generated outputs.
        """
        loss = 0
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean(torch.clamp(1 - dr, min=0))
            g_loss = torch.mean(torch.clamp(1 + dg, min=0))
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())
        return loss, r_losses, g_losses
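
Together these implement the hinge GAN objective: the discriminator is pushed above +1 on real scores and below -1 on generated scores, while the generator is rewarded for raising its scores above +1 (clamped at zero). A toy sketch with random score maps standing in for two sub-discriminators:

import torch

disc_real = [torch.randn(4, 1, 100), torch.randn(4, 1, 50)]
disc_fake = [torch.randn(4, 1, 100), torch.randn(4, 1, 50)]

d_loss, r_losses, g_losses = DiscriminatorLoss()(disc_real, disc_fake)
adv_loss, per_disc = GeneratorLoss()(disc_fake)
print(d_loss.item(), adv_loss.item())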
class FeatureMatchingLoss(nn.Module):
    """
    Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
    """

    def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
        """
        Args:
            fmap_r (List[List[Tensor]]): List of feature maps from real samples.
            fmap_g (List[List[Tensor]]): List of feature maps from generated samples.

        Returns:
            Tensor: The calculated feature matching loss.
        """
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))
        return loss
class DACGANLoss(nn.Module):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        # d_fake = self.discriminator(fake.audio_data)
        # d_real = self.discriminator(real.audio_data)
        d_fake = self.discriminator(fake)
        d_real = self.discriminator(real)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)
        return loss_d

    def generator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake, real)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0
        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
        return loss_g, loss_feature
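
DACGANLoss expects a discriminator whose output is a list of sub-discriminator results, each a list of feature maps with the final score map last; the generator loss then combines a least-squares adversarial term with an L1 feature-matching term over all but the last entry. A hedged sketch with a mock single-branch discriminator (the mock is illustrative, not the DAC discriminator from discriminator_dac.py):

import torch
from torch import nn

class MockDisc(nn.Module):
    def forward(self, x):
        # one sub-discriminator: two "feature maps", then the final score map
        return [[x.mean(dim=1, keepdim=True), x.abs(), x.mean(dim=-1, keepdim=True)]]

gan = DACGANLoss(MockDisc())
real = torch.randn(2, 1, 4096)
fake = torch.randn(2, 1, 4096, requires_grad=True)
d_loss = gan.discriminator_loss(fake, real)            # trains the discriminator
adv_loss, feat_loss = gan.generator_loss(fake, real)   # trains the generator
print(d_loss.item(), adv_loss.item(), feat_loss.item())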
examples/music_generation/inspiremusic/wavtokenizer/decoder/models.py
from typing import List, Optional

import torch
from torch import nn
from torch.nn.utils import weight_norm

from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h = q.shape
        q = q.permute(0, 2, 1)  # b,hw,c
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                        C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                    and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
                                                None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

        self.temb_ch = 0
        block_in = dim
        dropout = 0.1
        attn_type = "vanilla"
        pos_net: List[nn.Module] = [
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            make_attn(block_in, attn_type=attn_type),
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout),
            Normalize(block_in),
        ]
        self.pos_net = nn.Sequential(*pos_net)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embed(x)
        x = self.pos_net(x)
        if self.adanorm:
            # assert bandwidth_id is not None
            if bandwidth_id is None:
                bandwidth_id = torch.tensor(0, device=x.device)  # default to the first conditioning embedding
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x
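
A hedged instantiation sketch for the backbone above (dimensions are illustrative). With adanorm_num_embeddings left at None the plain LayerNorm branch is taken, so no bandwidth_id is required:

import torch

backbone = VocosBackbone(input_channels=512, dim=384, intermediate_dim=1152, num_layers=8)
feats = torch.randn(2, 512, 100)   # (B, C, L) features
out = backbone(feats)              # (B, L, H) = (2, 100, 384), ready for a FourierHead
print(out.shape)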
class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
    """

    def __init__(self, input_channels, dim, num_blocks, layer_scale_init_value=None):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        self.resnet = nn.Sequential(
            *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.embed(x)
        x = self.resnet(x)
        x = x.transpose(1, 2)
        return x
examples/music_generation/inspiremusic/wavtokenizer/decoder/modules.py
from typing import Optional
from typing import Tuple

import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x
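
AdaLayerNorm first normalizes over the channel dimension (without learnable affine parameters) and then applies a scale/shift pair looked up from the integer condition id, e.g. a bandwidth index. A short sketch:

import torch

ada = AdaLayerNorm(num_embeddings=4, embedding_dim=384)
x = torch.randn(2, 100, 384)   # (B, T, C)
cond = torch.tensor(2)         # which conditioning embedding to use
print(ada(x, cond).shape)      # torch.Size([2, 100, 384])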
class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, ...] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
            ]
        )

        self.gamma = nn.ParameterList(
            [
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)
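
symlog and symexp form an invertible pair (symexp undoes symlog), and safe_log never returns -inf because its argument is clipped at clip_val first. A quick check:

import torch

x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
print(torch.allclose(symexp(symlog(x)), x, atol=1e-4))   # True
print(safe_log(torch.tensor([0.0])))                     # log(1e-7), not -inf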
examples/music_generation/inspiremusic/wavtokenizer/decoder/pretrained.py
import os
from typing import Tuple, Any, Union, Dict

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn

from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)
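
instantiate_class performs a dotted import of class_path and calls the class with init_args; this is how the feature extractor, backbone and head are built from the YAML config. A hedged example with a stand-in class_path (not an actual InspireMusic config entry):

init = {
    "class_path": "torch.nn.Linear",
    "init_args": {"in_features": 8, "out_features": 4},
}
layer = instantiate_class(args=(), init=init)   # equivalent to torch.nn.Linear(8, 4)
print(layer)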
class WavTokenizer(nn.Module):
    """
    WavTokenizer is a Fourier-based neural vocoder for audio synthesis, adapted from Vocos.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["backbone"])
        head = instantiate_class(args=(), init=config["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained(self, repo_id: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        model = self.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu")
        if isinstance(model.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    @classmethod
    def from_hparams_feat(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from hyperparameters stored in a yaml configuration
        file (under the model.init_args section).
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained_feat(self, config_path, model_path):
        """
        Class method to create a new WavTokenizer model instance from a locally stored configuration file and
        training checkpoint.
        """
        model = self.from_hparams_feat(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @classmethod
    def estimator(self, config_path, model_path):
        """
        Class method to create a new WavTokenizer model instance from a locally stored configuration file and
        training checkpoint.
        """
        model = self.from_hparams_feat(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @classmethod
    def from_pretrained0911(self, config_path, model_folder_path):
        """
        Class method to create a new WavTokenizer model instance by averaging the best checkpoints found in a
        local folder.
        """
        model = self.from_hparams0802(config_path)

        models = os.listdir(model_folder_path)
        val_loss = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
            val_loss.append(item[-11:-5])
        val_loss.sort()
        val_loss = val_loss[:3]  # keep the 3 best-performing checkpoints (lowest validation loss) and average them

        state_dict = dict()
        state_dicts = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
            ll = item[-11:-5]
            if ll not in val_loss:
                continue
            model_path = model_folder_path + '/' + item
            state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
            state_dict_single = dict()
            for k, v in state_dict_raw.items():
                if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                    state_dict_single[k] = v
            state_dicts.append(state_dict_single)

        for kk in state_dicts[0].keys():
            vv = state_dicts[0][kk]
            for i in range(1, len(state_dicts)):
                ss = state_dicts[i]
                vv += ss[kk]
            vm = vv / len(state_dicts)
            state_dict[kk] = vm

        model.load_state_dict(state_dict)
        model.eval()
        return model
    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features, _, _ = self.feature_extractor(audio_input, **kwargs)  # 0818
        audio_output = self.decode(features, **kwargs)
        return audio_output

    # 0818
    @torch.inference_mode()
    def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
        return features, discrete_codes

    # 0818
    @torch.inference_mode()
    def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
        return features, discrete_codes

    @torch.inference_mode()
    def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        _, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs)
        discrete_codes = discrete_codes.clamp(min=0, max=16383)
        return discrete_codes

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.
        """
        assert isinstance(
            self.feature_extractor, EncodecFeatures
        ), "Feature extractor should be an instance of EncodecFeatures"

        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)

        tmp = torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0)
        features = features.transpose(1, 2)

        return features
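
The various from_pretrained* loaders above all share one step: they keep only generator weights (backbone, head, feature extractor) from a training checkpoint and drop the discriminator entries. A small self-contained sketch of that filtering (the keys are fabricated for illustration):

state_dict_raw = {
    "backbone.embed.weight": 1,
    "head.out.weight": 2,
    "feature_extractor.encodec.foo": 3,
    "multiperioddisc.something": 4,   # present in training checkpoints, dropped at inference
}
state_dict = {k: v for k, v in state_dict_raw.items()
              if k.startswith(("backbone.", "head.", "feature_extractor."))}
print(sorted(state_dict))   # ['backbone.embed.weight', 'feature_extractor.encodec.foo', 'head.out.weight']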
examples/music_generation/inspiremusic/wavtokenizer/decoder/pretrained_model.py
from typing import Tuple, Any, Union, Dict

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn

from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from inspiremusic.wavtokenizer.decoder.heads import FourierHead
from inspiremusic.wavtokenizer.decoder.models import Backbone
from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)
class WavTokenizer(nn.Module):
    """
    WavTokenizer is a Fourier-based neural vocoder for audio synthesis, adapted from Vocos.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
        multiperioddisc: MultiPeriodDiscriminator,
        multiresddisc: MultiResolutionDiscriminator,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head
        self.multiperioddisc = multiperioddisc
        self.multiresddisc = multiresddisc

    @classmethod
    def from_hparams0828(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
                    multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
                    multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
        return model

    @classmethod
    def from_pretrained0828(self, config_path, model_path):
        """
        Class method to create a new WavTokenizer model instance from a locally stored configuration file and
        training checkpoint.
        """
        model = self.from_hparams0828(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
                    or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
                state_dict[k] = v
        # if isinstance(model.feature_extractor, EncodecFeatures):
        #     encodec_parameters = {
        #         "feature_extractor.encodec." + key: value
        #         for key, value in model.feature_extractor.encodec.state_dict().items()
        #     }
        #     state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        return model
    @classmethod
    def from_hparams0802(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new WavTokenizer model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained0802(self, config_path, model_path):
        """
        Class method to create a new WavTokenizer model instance from a locally stored configuration file and
        training checkpoint.
        """
        model = self.from_hparams0802(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        # if isinstance(model.feature_extractor, EncodecFeatures):
        #     encodec_parameters = {
        #         "feature_extractor.encodec." + key: value
        #         for key, value in model.feature_extractor.encodec.state_dict().items()
        #     }
        #     state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features, _, _ = self.feature_extractor(audio_input, **kwargs)  # 0818
        audio_output = self.decode(features, **kwargs)
        return audio_output

    # 0818
    @torch.inference_mode()
    def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        features, _, _ = self.feature_extractor(audio_input, **kwargs)
        return features

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.
        """
        assert isinstance(
            self.feature_extractor, EncodecFeatures
        ), "Feature extractor should be an instance of EncodecFeatures"

        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = features.transpose(1, 2)

        return features
examples/music_generation/inspiremusic/wavtokenizer/decoder/spectral_ops.py
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
import pdb


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        # assert (window_envelope > 1e-11).all()
        if not torch.all(window_envelope > 1e-11):
            window_envelope = torch.clamp(window_envelope, min=1e-11)
        y = y / window_envelope
        return y
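
A consistency sketch for the module above (illustrative): with "center" padding it defers to torch.istft, so a torch.stft round trip recovers the overlapping part of the signal up to numerical error:

import torch

n_fft, hop, win = 1024, 256, 1024
istft = ISTFT(n_fft=n_fft, hop_length=hop, win_length=win, padding="center")
x = torch.randn(1, 24000)
spec = torch.stft(x, n_fft, hop_length=hop, win_length=win,
                  window=torch.hann_window(win), center=True, return_complex=True)
y = istft(spec)
print(torch.allclose(x[..., : y.shape[-1]], y, atol=1e-3))   # expected: True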
    def onnx_forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape
        pdb.set_trace()

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        # assert (window_envelope > 1e-11).all()
        if not torch.all(window_envelope > 1e-11):
            window_envelope = torch.clamp(window_envelope, min=1e-11)
        y = y / window_envelope
        return y
class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)
class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
        y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
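
A round-trip sketch for the two transforms above (illustrative): in the frame interior, MDCT followed by IMDCT reproduces the waveform up to numerical error, since the cosine window satisfies the Princen-Bradley condition at 50% overlap; the first and last half-frame are excluded because alias cancellation there would need a neighbouring frame:

import torch

frame_len = 1280
mdct = MDCT(frame_len=frame_len, padding="same")
imdct = IMDCT(frame_len=frame_len, padding="same")

audio = torch.randn(1, frame_len * 10)   # length should be a multiple of frame_len // 2
coeffs = mdct(audio)                     # (B, L, N) with N = frame_len // 2
recon = imdct(coeffs)
inner = slice(frame_len // 2, audio.shape[-1] - frame_len // 2)
print(coeffs.shape, torch.allclose(audio[:, inner], recon[:, inner], atol=1e-4))   # expected: True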