ModelZoo / index-tts-vllm / Commits

Commit ab9c00af
authored Jan 07, 2026 by yangzhong

init submission

Pipeline #3176 failed with stages in 0 seconds
Changes: 316 | Pipelines: 1

Showing 20 changed files with 2113 additions and 0 deletions (+2113, -0)
indextts/s2mel/dac/__pycache__/__init__.cpython-310.pyc                 +0    -0
indextts/s2mel/dac/model/__init__.py                                    +4    -0
indextts/s2mel/dac/model/__pycache__/__init__.cpython-310.pyc           +0    -0
indextts/s2mel/dac/model/__pycache__/base.cpython-310.pyc               +0    -0
indextts/s2mel/dac/model/__pycache__/dac.cpython-310.pyc                +0    -0
indextts/s2mel/dac/model/__pycache__/discriminator.cpython-310.pyc      +0    -0
indextts/s2mel/dac/model/__pycache__/encodec.cpython-310.pyc            +0    -0
indextts/s2mel/dac/model/base.py                                        +294  -0
indextts/s2mel/dac/model/dac.py                                         +400  -0
indextts/s2mel/dac/model/discriminator.py                               +228  -0
indextts/s2mel/dac/model/encodec.py                                     +321  -0
indextts/s2mel/dac/nn/__init__.py                                       +3    -0
indextts/s2mel/dac/nn/__pycache__/__init__.cpython-310.pyc              +0    -0
indextts/s2mel/dac/nn/__pycache__/layers.cpython-310.pyc                +0    -0
indextts/s2mel/dac/nn/__pycache__/loss.cpython-310.pyc                  +0    -0
indextts/s2mel/dac/nn/__pycache__/quantize.cpython-310.pyc              +0    -0
indextts/s2mel/dac/nn/layers.py                                         +33   -0
indextts/s2mel/dac/nn/loss.py                                           +368  -0
indextts/s2mel/dac/nn/quantize.py                                       +339  -0
indextts/s2mel/dac/utils/__init__.py                                    +123  -0
indextts/s2mel/dac/__pycache__/__init__.cpython-310.pyc  0 → 100644
File added
indextts/s2mel/dac/model/__init__.py  0 → 100644

from .base import CodecMixin
from .base import DACFile
from .dac import DAC
from .discriminator import Discriminator
indextts/s2mel/dac/model/__pycache__/__init__.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/model/__pycache__/base.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/model/__pycache__/dac.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/model/__pycache__/discriminator.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/model/__pycache__/encodec.cpython-310.pyc  0 → 100644
File added
indextts/s2mel/dac/model/base.py  0 → 100644

import math
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
import torch
import tqdm
from audiotools import AudioSignal
from torch import nn

SUPPORTED_VERSIONS = ["1.0.0"]


@dataclass
class DACFile:
    codes: torch.Tensor

    # Metadata
    chunk_length: int
    original_length: int
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

    def save(self, path):
        artifacts = {
            "codes": self.codes.numpy().astype(np.uint16),
            "metadata": {
                "input_db": self.input_db.numpy().astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **artifacts["metadata"])


class CodecMixin:
    @property
    def padding(self):
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to reconstruct
        win_duration : float, optional
            window duration in seconds, by default 1.0
        verbose : bool, optional
            by default False
        normalize_db : float, optional
            normalize db, by default -16

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        audio_signal = audio_signal.clone()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # If audio is > 10 minutes long, use the ffmpeg versions
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples to nearest hop length multiple
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )

        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        self.padding = original_padding
        return dac_file

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding
        self.padding = obj.padding

        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        resample_fn = recons.resample
        loudness_fn = recons.loudness

        # If audio is > 10 minutes long, use the ffmpeg versions
        if recons.signal_duration >= 10 * 60 * 60:
            resample_fn = recons.ffmpeg_resample
            loudness_fn = recons.ffmpeg_loudness

        recons.normalize(obj.input_db)
        resample_fn(obj.sample_rate)
        recons = recons[..., : obj.original_length]
        loudness_fn()
        recons.audio_data = recons.audio_data.reshape(
            -1, obj.channels, obj.original_length
        )

        self.padding = original_padding
        return recons
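For orientation, here is a minimal usage sketch of the compress/decompress round trip defined above. It is not part of the commit: the random 4-second signal stands in for real audio, the untrained default DAC() takes the place of a loaded checkpoint, and the file name example.dac is arbitrary.

# Hedged usage sketch (not part of this commit): compress an AudioSignal to a .dac
# file and reconstruct it again via CodecMixin.compress / decompress.
import torch
from audiotools import AudioSignal
from indextts.s2mel.dac.model import DAC

model = DAC()          # untrained default 44.1 kHz config; load real weights in practice
model.eval()

signal = AudioSignal(torch.randn(1, 1, 44100 * 4), 44100)   # 4 s stand-in for real audio
dac_file = model.compress(signal, win_duration=1.0)         # windowed encoding into discrete codes
saved_path = dac_file.save("example.dac")                   # np.save under the hood, suffix forced to .dac

restored = model.decompress(saved_path, verbose=True)       # AudioSignal at the original sample rate
print(dac_file.codes.shape, restored.signal_length)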
indextts/s2mel/dac/model/dac.py  0 → 100644

import math
from typing import List
from typing import Union

import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from torch import nn

from .base import CodecMixin
from indextts.s2mel.dac.nn.layers import Snake1d
from indextts.s2mel.dac.nn.layers import WNConv1d
from indextts.s2mel.dac.nn.layers import WNConvTranspose1d
from indextts.s2mel.dac.nn.quantize import ResidualVectorQuantize
from .encodec import SConv1d, SConvTranspose1d, SLSTM


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv1d_type = SConv1d  # if causal else WNConv1d
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv1d_type(
                dim,
                dim,
                kernel_size=7,
                dilation=dilation,
                padding=pad,
                causal=causal,
                norm='weight_norm',
            ),
            Snake1d(dim),
            conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
        super().__init__()
        conv1d_type = SConv1d  # if causal else WNConv1d
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv1d_type(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm='weight_norm',
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        conv1d_type = SConv1d  # if causal else WNConv1d
        # Create first convolution
        self.block = [
            conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')
        ]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]

        # Add LSTM if needed
        self.use_lstm = lstm
        if lstm:
            self.block += [SLSTM(d_model, lstm)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)

    def reset_cache(self):
        # recursively find all submodules named SConv1d in self.block and use their reset_cache method
        def reset_cache(m):
            if isinstance(m, SConv1d) or isinstance(m, SLSTM):
                m.reset_cache()
                return
            for child in m.children():
                reset_cache(child)

        reset_cache(self.block)


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
        super().__init__()
        conv1d_type = SConvTranspose1d  # if causal else WNConvTranspose1d
        self.block = nn.Sequential(
            Snake1d(input_dim),
            conv1d_type(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm='weight_norm',
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        conv1d_type = SConv1d  # if causal else WNConv1d
        # Add first conv layer
        layers = [
            conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')
        ]

        if lstm:
            layers += [SLSTM(channels, num_layers=lstm)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class DAC(BaseModel, CodecMixin):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        lstm: int = 2,
        causal: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)

        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            lstm=lstm,
            causal=causal,
        )
        self.sample_rate = sample_rate
        self.apply(init_weights)

        self.delay = self.get_delay()

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "length" : int
                Number of samples in input audio
        """
        z = self.encoder(audio_data)
        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            z, n_quantizers
        )
        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input
        length : int, optional
            Number of samples in output audio, by default None

        Returns
        -------
        dict
            A dictionary with the following keys:
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
        """
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "length" : int
                Number of samples in input audio
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        z, codes, latents, commitment_loss, codebook_loss = self.encode(
            audio_data, n_quantizers
        )

        x = self.decode(z)
        return {
            "audio": x[..., :length],
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }


if __name__ == "__main__":
    import numpy as np
    from functools import partial

    model = DAC().to("cpu")

    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Make a forward pass
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Create gradient variable
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1

    # Make a backward pass
    out.backward(grad)

    # Check non-zero values
    gradmap = x.grad.squeeze(0)
    gradmap = (gradmap != 0).sum(0)  # sum across features
    rf = (gradmap != 0).sum()

    print(f"Receptive field: {rf.item()}")

    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)
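The __main__ block above exercises forward() and the compress/decompress round trip; the sketch below, an illustration rather than part of the commit, shows the lower-level encode/decode path and the code tensor it produces (shapes assume the default configuration, whose hop length is 512).

# Hedged sketch (not part of this commit): direct encode/decode on a padded waveform.
import torch
from indextts.s2mel.dac.model import DAC

model = DAC().eval()
audio = torch.randn(1, 1, 44100)                           # B x 1 x T, one second of mono audio

with torch.no_grad():
    audio = model.preprocess(audio, model.sample_rate)     # right-pad T to a hop-length multiple
    z, codes, latents, commitment_loss, codebook_loss = model.encode(audio)
    recon = model.decode(z)                                # B x 1 x T_padded

print(codes.shape)   # B x n_codebooks x frames, e.g. torch.Size([1, 9, 87]) for this input
print(recon.shape)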
indextts/s2mel/dac/model/discriminator.py  0 → 100644

import torch
import torch.nn as nn
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import ml
from audiotools import STFTParams
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv1d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


def WNConv2d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


class MPD(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
        return x

    def forward(self, x):
        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for layer in self.convs:
            x = layer(x)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return fmap


class MSD(nn.Module):
    def __init__(self, rate: int = 1, sample_rate: int = 44100):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                WNConv1d(1, 16, 15, 1, padding=7),
                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
                WNConv1d(1024, 1024, 5, 1, padding=2),
            ]
        )
        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
        self.sample_rate = sample_rate
        self.rate = rate

    def forward(self, x):
        x = AudioSignal(x, self.sample_rate)
        x.resample(self.sample_rate // self.rate)
        x = x.audio_data

        fmap = []

        for l in self.convs:
            x = l(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]


class MRD(nn.Module):
    def __init__(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Complex multi-band spectrogram discriminator.

        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_factor : float, optional
            Hop factor of the STFT, defaults to ``0.25 * window_length``.
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run discriminator over.
        """
        super().__init__()

        self.window_length = window_length
        self.hop_factor = hop_factor
        self.sample_rate = sample_rate
        self.stft_params = STFTParams(
            window_length=window_length,
            hop_length=int(window_length * hop_factor),
            match_stride=True,
        )

        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands

        ch = 32
        convs = lambda: nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def spectrogram(self, x):
        x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
        x = torch.view_as_real(x.stft())
        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x):
        x_bands = self.spectrogram(x)
        fmap = []

        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for layer in stack:
                band = layer(band)
                fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


class Discriminator(nn.Module):
    def __init__(
        self,
        rates: list = [],
        periods: list = [2, 3, 5, 7, 11],
        fft_sizes: list = [2048, 1024, 512],
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Discriminator that combines multiple discriminators.

        Parameters
        ----------
        rates : list, optional
            sampling rates (in Hz) to run MSD at, by default []
            If empty, MSD is not used.
        periods : list, optional
            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
        fft_sizes : list, optional
            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run MRD at, by default `BANDS`
        """
        super().__init__()
        discs = []
        discs += [MPD(p) for p in periods]
        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
        self.discriminators = nn.ModuleList(discs)

    def preprocess(self, y):
        # Remove DC offset
        y = y - y.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        return y

    def forward(self, x):
        x = self.preprocess(x)
        fmaps = [d(x) for d in self.discriminators]
        return fmaps


if __name__ == "__main__":
    disc = Discriminator()
    x = torch.zeros(1, 1, 44100)
    results = disc(x)
    for i, result in enumerate(results):
        print(f"disc{i}")
        for i, r in enumerate(result):
            print(r.shape, r.mean(), r.min(), r.max())
        print()
indextts/s2mel/dac/model/encodec.py  0 → 100644

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

import typing as tp

import einops


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already check was in CONV_NORMALIZATION, so any other choice
        # doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`.
    """
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect', **kwargs):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

        self.cache_enabled = False

    def reset_cache(self):
        """Reset the cache when starting a new stream."""
        self.cache = None
        self.cache_enabled = True

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)

        if self.causal:
            # Left padding for causal
            if self.cache_enabled and self.cache is not None:
                # Concatenate the cache (previous inputs) with the new input for streaming
                x = torch.cat([self.cache, x], dim=2)
            else:
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)

        # Store the most recent input frames for future cache use
        if self.cache_enabled:
            if self.cache is None:
                # Initialize cache with zeros (at the start of streaming)
                self.cache = torch.zeros(B, C, kernel_size - 1, device=x.device)
            # Update the cache by storing the latest input frames
            if kernel_size > 1:
                self.cache = x[:, :, -kernel_size + 1:].detach()  # Only store the necessary frames

        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
        self.hidden = None
        self.cache_enabled = False

    def forward(self, x):
        x = x.permute(2, 0, 1)
        if self.training or not self.cache_enabled:
            y, _ = self.lstm(x)
        else:
            y, self.hidden = self.lstm(x, self.hidden)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y

    def reset_cache(self):
        self.hidden = None
        self.cache_enabled = True
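The cache / reset_cache machinery in SConv1d and SLSTM above supports streaming inference: once reset_cache() has been called, a causal layer keeps its last kernel_size - 1 input frames and prepends them to the next chunk instead of re-padding. A rough sketch of that usage follows; the channel counts and the 512-sample chunk size are arbitrary illustrative choices, not values taken from this commit.

# Hedged streaming sketch (not part of this commit) for the causal SConv1d defined above.
import torch
from indextts.s2mel.dac.model.encodec import SConv1d

conv = SConv1d(1, 8, kernel_size=7, stride=1, causal=True, norm='weight_norm')
conv.eval()
conv.reset_cache()                       # enable caching; the first chunk is padded as usual

stream = torch.randn(1, 1, 4096)         # pretend this signal arrives incrementally
chunks = []
with torch.no_grad():
    for chunk in stream.split(512, dim=-1):   # feed 512-sample chunks
        chunks.append(conv(chunk))            # each call reuses the cached context frames
y_stream = torch.cat(chunks, dim=-1)
print(y_stream.shape)                    # concatenated streaming output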
indextts/s2mel/dac/nn/__init__.py  0 → 100644

from . import layers
from . import loss
from . import quantize
indextts/s2mel/dac/nn/__pycache__/__init__.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/nn/__pycache__/layers.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/nn/__pycache__/loss.cpython-310.pyc  0 → 100644
File added

indextts/s2mel/dac/nn/__pycache__/quantize.cpython-310.pyc  0 → 100644
File added
indextts/s2mel/dac/nn/layers.py  0 → 100644

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
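As a quick illustration (an assumption-level sketch, not part of the commit), the snake activation above is shape-preserving on B x C x T input, with one learnable alpha per channel:

# Hedged sanity-check sketch for Snake1d (not part of this commit).
import torch
from indextts.s2mel.dac.nn.layers import Snake1d

act = Snake1d(channels=16)
x = torch.randn(2, 16, 100)       # B x C x T
y = act(x)
print(y.shape)                    # torch.Size([2, 16, 100]); same shape as the input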
indextts/s2mel/dac/nn/loss.py  0 → 100644

import typing
from typing import List

import torch
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import STFTParams
from torch import nn


class L1Loss(nn.L1Loss):
    """L1 Loss between AudioSignals. Defaults
    to comparing ``audio_data``, but any
    attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        return super().forward(x, y)


class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : int, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or 'none'), by default 'mean'
    zero_mean : int, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : int, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: int = True,
        reduction: str = "mean",
        zero_mean: int = True,
        clip_min: int = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        eps = 1e-8
        # nb, nc, nt
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr


class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    References
    ----------
    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
       "DDSP: Differentiable Digital Signal Processing."
       International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.
        """
        loss = 0.0
        for s in self.stft_params:
            x.stft(s.window_length, s.hop_length, s.window_type)
            y.stft(s.window_length, s.hop_length, s.window_type)
            loss += self.log_weight * self.loss_fn(
                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
        return loss


class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss


class GANLoss(nn.Module):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        d_fake = self.discriminator(fake.audio_data)
        d_real = self.discriminator(real.audio_data)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)
        return loss_d

    def generator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake, real)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0

        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
        return loss_g, loss_feature
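A hedged sketch of how these reconstruction losses are typically combined on pairs of AudioSignals; the random signals and the relative weights below are illustrative assumptions, not values prescribed by this commit.

# Hedged sketch (not part of this commit): combining the losses defined above.
import torch
from audiotools import AudioSignal
from indextts.s2mel.dac.nn.loss import L1Loss, MelSpectrogramLoss, MultiScaleSTFTLoss

recons = AudioSignal(torch.randn(1, 1, 44100), 44100)     # stand-in for model output
target = AudioSignal(torch.randn(1, 1, 44100), 44100)     # stand-in for ground truth

waveform_loss = L1Loss()
stft_loss = MultiScaleSTFTLoss()
mel_loss = MelSpectrogramLoss()

total = (
    1.0 * waveform_loss(recons, target)
    + 1.0 * stft_loss(recons, target)
    + 15.0 * mel_loss(recons, target)      # illustrative weighting only
)
print(float(total))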
indextts/s2mel/dac/nn/quantize.py
0 → 100644
View file @
ab9c00af
from
typing
import
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
torch.nn.utils
import
weight_norm
from
indextts.s2mel.dac.nn.layers
import
WNConv1d
class
VectorQuantizeLegacy
(
nn
.
Module
):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
removed in-out projection
"""
def
__init__
(
self
,
input_dim
:
int
,
codebook_size
:
int
):
super
().
__init__
()
self
.
codebook_size
=
codebook_size
self
.
codebook
=
nn
.
Embedding
(
codebook_size
,
input_dim
)
def
forward
(
self
,
z
,
z_mask
=
None
):
"""Quantized the input tensor using a fixed codebook and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
Returns
-------
Tensor[B x D x T]
Quantized continuous representation of input
Tensor[1]
Commitment loss to train encoder to predict vectors closer to codebook
entries
Tensor[1]
Codebook loss to update the codebook
Tensor[B x T]
Codebook indices (quantized discrete representation of input)
Tensor[B x D x T]
Projected latents (continuous representation of input before quantization)
"""
z_e
=
z
z_q
,
indices
=
self
.
decode_latents
(
z
)
if
z_mask
is
not
None
:
commitment_loss
=
(
F
.
mse_loss
(
z_e
,
z_q
.
detach
(),
reduction
=
"none"
).
mean
(
1
)
*
z_mask
).
sum
()
/
z_mask
.
sum
()
codebook_loss
=
(
F
.
mse_loss
(
z_q
,
z_e
.
detach
(),
reduction
=
"none"
).
mean
(
1
)
*
z_mask
).
sum
()
/
z_mask
.
sum
()
else
:
commitment_loss
=
F
.
mse_loss
(
z_e
,
z_q
.
detach
())
codebook_loss
=
F
.
mse_loss
(
z_q
,
z_e
.
detach
())
z_q
=
(
z_e
+
(
z_q
-
z_e
).
detach
()
)
# noop in forward pass, straight-through gradient estimator in backward pass
return
z_q
,
indices
,
z_e
,
commitment_loss
,
codebook_loss
def
embed_code
(
self
,
embed_id
):
return
F
.
embedding
(
embed_id
,
self
.
codebook
.
weight
)
def
decode_code
(
self
,
embed_id
):
return
self
.
embed_code
(
embed_id
).
transpose
(
1
,
2
)
def
decode_latents
(
self
,
latents
):
encodings
=
rearrange
(
latents
,
"b d t -> (b t) d"
)
codebook
=
self
.
codebook
.
weight
# codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings
=
F
.
normalize
(
encodings
)
codebook
=
F
.
normalize
(
codebook
)
# Compute euclidean distance with codebook
dist
=
(
encodings
.
pow
(
2
).
sum
(
1
,
keepdim
=
True
)
-
2
*
encodings
@
codebook
.
t
()
+
codebook
.
pow
(
2
).
sum
(
1
,
keepdim
=
True
).
t
()
)
indices
=
rearrange
((
-
dist
).
max
(
1
)[
1
],
"(b t) -> b t"
,
b
=
latents
.
size
(
0
))
z_q
=
self
.
decode_code
(
indices
)
return
z_q
,
indices
class
VectorQuantize
(
nn
.
Module
):
"""
Implementation of VQ similar to Karpathy's repo:
https://github.com/karpathy/deep-vector-quantization
Additionally uses following tricks from Improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def
__init__
(
self
,
input_dim
:
int
,
codebook_size
:
int
,
codebook_dim
:
int
):
super
().
__init__
()
self
.
codebook_size
=
codebook_size
self
.
codebook_dim
=
codebook_dim
self
.
in_proj
=
WNConv1d
(
input_dim
,
codebook_dim
,
kernel_size
=
1
)
self
.
out_proj
=
WNConv1d
(
codebook_dim
,
input_dim
,
kernel_size
=
1
)
self
.
codebook
=
nn
.
Embedding
(
codebook_size
,
codebook_dim
)
    def forward(self, z, z_mask=None):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN): project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        if z_mask is not None:
            commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
            codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
        else:
            commitment_loss = F.mse_loss(z_e, z_q.detach())
            codebook_loss = F.mse_loss(z_q, z_e.detach())

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e
    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        return z_q, indices
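# --- Illustrative usage sketch (not from the original file) ------------------
# How VectorQuantize above is typically exercised; shapes assume the factorized
# codebook_dim=8 projection described in the class docstring:
#
#   vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
#   z = torch.randn(4, 512, 80)                       # (B x D x T)
#   z_q, commit_loss, cb_loss, indices, z_e = vq(z)
#   # z_q: (4, 512, 80)   indices: (4, 80)   z_e: (4, 8, 80)
# ------------------------------------------------------------------------------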
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An End-to-End Neural Audio Codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout
    def forward(self, z, n_quantizers: int = None):
        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.

        Returns
        -------
        z_q : Tensor[B x D x T]
            Quantized continuous representation of input
        codes : Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        latents : Tensor[B x N*D x T]
            Projected latents (continuous representation of input before quantization)
        commitment_loss : Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        codebook_loss : Tensor[1]
            Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss
    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes
    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
if __name__ == "__main__":
    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
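As a quick orientation for readers of the commit, the following is a minimal usage sketch of the ResidualVectorQuantize module defined above. It assumes the package is importable as indextts.s2mel.dac (the import path is inferred from the repository layout) and is illustrative only, not part of the committed file:

import torch
from indextts.s2mel.dac.nn.quantize import ResidualVectorQuantize

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
rvq.eval()

x = torch.randn(2, 512, 80)                         # (B x D x T) encoder features
z_q, codes, latents, commit_loss, cb_loss = rvq(x)
print(codes.shape)                                  # torch.Size([2, 9, 80]): one index per codebook per frame

# Round-trip: rebuild the continuous representation from the discrete codes alone
z_hat, z_p, _ = rvq.from_codes(codes)
assert z_hat.shape == x.shape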
indextts/s2mel/dac/utils/__init__.py
0 → 100644
View file @
ab9c00af
from pathlib import Path

import argbind
from audiotools import ml

import indextts.s2mel.dac as dac

DAC = dac.model.DAC
Accelerator = ml.Accelerator

__MODEL_LATEST_TAGS__ = {
    ("44khz", "8kbps"): "0.0.1",
    ("24khz", "8kbps"): "0.0.4",
    ("16khz", "8kbps"): "0.0.5",
    ("44khz", "16kbps"): "1.0.0",
}

__MODEL_URLS__ = {
    (
        "44khz",
        "0.0.1",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
    (
        "24khz",
        "0.0.4",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
    (
        "16khz",
        "0.0.5",
        "8kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
    (
        "44khz",
        "1.0.0",
        "16kbps",
    ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
}


@argbind.bind(group="download", positional=True, without_prefix=True)
def download(
    model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
):
    """
    Function that downloads the weights file from URL if a local cache is not found.

    Parameters
    ----------
    model_type : str
        The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
    model_bitrate: str
        Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
        Only the 44khz model supports 16kbps.
    tag : str
        The tag of the model to download. Defaults to "latest".

    Returns
    -------
    Path
        Directory path required to load model via audiotools.
    """
    model_type = model_type.lower()
    tag = tag.lower()
    assert model_type in [
        "44khz",
        "24khz",
        "16khz",
    ], "model_type must be one of '44khz', '24khz', or '16khz'"

    assert model_bitrate in [
        "8kbps",
        "16kbps",
    ], "model_bitrate must be one of '8kbps', or '16kbps'"

    if tag == "latest":
        tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]

    download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)

    if download_link is None:
        raise ValueError(
            f"Could not find model with tag {tag} and model type {model_type}"
        )

    local_path = (
        Path.home()
        / ".cache"
        / "descript"
        / "dac"
        / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
    )
    if not local_path.exists():
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download the model
        import requests

        response = requests.get(download_link)

        if response.status_code != 200:
            raise ValueError(
                f"Could not download model. Received response code {response.status_code}"
            )
        local_path.write_bytes(response.content)

    return local_path


def load_model(
    model_type: str = "44khz",
    model_bitrate: str = "8kbps",
    tag: str = "latest",
    load_path: str = None,
):
    if not load_path:
        load_path = download(
            model_type=model_type, model_bitrate=model_bitrate, tag=tag
        )
    generator = DAC.load(load_path)
    return generator
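For completeness, a small sketch of how these helpers would typically be called, assuming the repository is installed so that indextts.s2mel.dac.utils is importable (illustrative only; the cache location follows the download() implementation above):

from indextts.s2mel.dac.utils import download, load_model

# Resolve (and cache under ~/.cache/descript/dac/) the requested weights file
weights_path = download(model_type="44khz", model_bitrate="8kbps", tag="latest")

# Or let load_model() call download() itself and build the DAC generator
model = load_model(model_type="44khz", model_bitrate="8kbps")
model.eval()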