ModelZoo / index-tts-vllm / Commits / ab9c00af

Commit ab9c00af, authored Jan 07, 2026 by yangzhong

    init submission

Pipeline #3176 failed with stages in 0 seconds
Changes: 316 · Pipelines: 1
Showing 20 changed files with 4021 additions and 0 deletions (+4021, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py (+27, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py (+346, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py (+46, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py (+37, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py (+14, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py (+317, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py (+388, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py (+135, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py (+125, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py (+414, -0)
indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py (+592, -0)
indextts/utils/maskgct/models/tts/maskgct/__pycache__/llama_nar.cpython-310.pyc (+0, -0)
indextts/utils/maskgct/models/tts/maskgct/__pycache__/maskgct_s2a.cpython-310.pyc (+0, -0)
indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt (+0, -0)
indextts/utils/maskgct/models/tts/maskgct/llama_nar.py (+650, -0)
indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py (+503, -0)
indextts/utils/maskgct_utils.py (+262, -0)
indextts/utils/text_utils.py (+42, -0)
indextts/utils/typical_sampling.py (+30, -0)
indextts/utils/utils.py (+93, -0)
indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Convolutional layers wrappers and utilities."""

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

from .norm import ConvLayerNorm


CONV_NORMALIZATIONS = frozenset(
    [
        "none",
        "weight_norm",
        "spectral_norm",
        "time_layer_norm",
        "layer_norm",
        "time_group_norm",
    ]
)


def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    elif norm == "spectral_norm":
        return spectral_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module


def get_norm_module(
    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "layer_norm":
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4             # once you removed padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zero",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "SConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        padding_total = (kernel_size - 1) * dilation - (stride - 1)
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride
        y = self.convtr(x)
        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
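A quick sanity check of the padding logic above: with stride 1, `SConv1d` pads so the output keeps the input's time length, in both causal and non-causal mode. A minimal sketch, assuming the module is importable (the flat `modules` import below is illustrative; in this commit the class lives under indextts/utils/maskgct/models/codec/speechtokenizer/modules):

import torch
from modules import SConv1d  # illustrative import path

x = torch.randn(2, 16, 100)  # (batch, channels, time)

# Causal: all padding goes to the left, so frame t only depends on inputs <= t.
causal_conv = SConv1d(16, 32, kernel_size=7, causal=True)
print(causal_conv(x).shape)  # torch.Size([2, 32, 100])

# Non-causal: padding is split between both sides (asymmetrically for odd totals).
conv = SConv1d(16, 32, kernel_size=7, causal=False)
print(conv(x).shape)         # torch.Size([2, 32, 100])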
indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""LSTM layers module."""

from torch import nn


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        skip: bool = True,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.bidirectional:
            x = x.repeat(1, 1, 2)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
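`SLSTM` keeps the convolutional (batch, channels, time) layout at its interface, permuting internally for `nn.LSTM` and adding the skip connection. A minimal sketch (import path illustrative, as above):

import torch
from modules import SLSTM  # illustrative import path

slstm = SLSTM(dimension=32, num_layers=2)
x = torch.randn(4, 32, 100)   # convolutional layout: (batch, channels, time)
y = slstm(x)
print(y.shape)                # torch.Size([4, 32, 100]) -- layout is preserved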
indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Normalization modules."""

import typing as tp

import einops
import torch
from torch import nn


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, "b ... t -> b t ...")
        x = super().forward(x)
        x = einops.rearrange(x, "b t ... -> b ... t")
        return x
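`ConvLayerNorm` normalizes over the channel dimension of a (batch, channels, time) tensor by temporarily moving channels last. A minimal sketch (import path illustrative):

import torch
from modules.norm import ConvLayerNorm  # illustrative import path

ln = ConvLayerNorm(8)          # normalized_shape = number of channels
x = torch.randn(4, 8, 50)      # (batch, channels, time)
y = ln(x)
print(y.shape)                 # torch.Size([4, 8, 50])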
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Arithmetic coder."""

import io
import math
import random
import typing as tp

import torch

from ..binary import BitPacker, BitUnpacker


def build_stable_quantized_cdf(
    pdf: torch.Tensor,
    total_range_bits: int,
    roundoff: float = 1e-8,
    min_range: int = 2,
    check: bool = True,
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove difference coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior when a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
    """
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    if check:
        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
        if (
            (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
        ).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf


class ArithmeticCoder:
    """ArithmeticCoder.
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we encode mostly likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
            Any time the current range width falls under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        while self.delta < 2**self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(
            math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
        )
        effective_high = int(
            math.floor(range_high * (self.delta / (2**self.total_range_bits)))
        )
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (
            effective_low,
            effective_high,
            range_low,
            range_high,
        )
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        outs = self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit
        return outs

    def flush(self):
        """Flush the remaining information to the stream."""
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()


class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdfs as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the same common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                self.current -= b1 << self.max_bit
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.
        """
        while self.delta < 2**self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(
                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
            )
            effective_high = int(
                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
            )
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))
        return sym


def test():
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()
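Beyond the randomized `test()` above, the core encode/decode round trip fits in a few lines. A sketch under the assumption that the sibling `binary` module (providing BitPacker/BitUnpacker) is importable, since ac.py depends on it; the flat import path is illustrative:

import io
import torch
from modules.quantization.ac import (  # illustrative import path
    ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf,
)

pdf = torch.tensor([0.7, 0.2, 0.1])   # fixed distribution over 3 symbols
symbols = [0, 0, 1, 2, 0]

fo = io.BytesIO()
coder = ArithmeticCoder(fo)
q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for s in symbols:
    coder.push(s, q_cdf)              # the same cdf must be used when decoding
coder.flush()

fo.seek(0)
decoder = ArithmeticDecoder(fo)
assert [decoder.pull(q_cdf) for _ in symbols] == symbols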
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F

from .distrib import broadcast_tensors, rank


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with a randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        out_quantized = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            if layers and i in layers:
                out_quantized.append(quantized)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses, out_quantized

    def encode(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        st: tp.Optional[int] = None,
    ) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        st = st or 0
        for layer in self.layers[st:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[st + i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
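The residual scheme above quantizes the input in stages, each layer coding what the previous layers left over. A minimal forward-pass sketch (kmeans_init=False avoids the data-dependent k-means initialization; the flat import path is illustrative):

import torch
from modules.quantization.core_vq import ResidualVectorQuantization  # illustrative

rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=64, codebook_size=256, kmeans_init=False
)
rvq.eval()                    # skip EMA codebook updates and code expiry
x = torch.randn(2, 64, 50)    # (batch, dim, time)
quantized, indices, losses, _ = rvq(x)
print(quantized.shape)        # torch.Size([2, 64, 50])
print(indices.shape)          # torch.Size([4, 2, 50]) -- one code per layer per frame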
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    # print('params[0].device ', params[0].device)
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        # src = int(rank()) # added code
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
                )
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
            )
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
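All of these helpers degrade gracefully when torch.distributed has not been initialized, which is what keeps the codec runnable in a single process. A quick illustration (flat import path illustrative):

from modules.quantization.distrib import world_size, average_metrics  # illustrative

# Without an initialized process group, the helpers fall back to no-ops:
print(world_size())                    # 1
print(average_metrics({"loss": 0.5}))  # {'loss': 0.5} -- returned unchanged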
indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Residual vector quantizer implementation."""

from dataclasses import dataclass, field
import math
import typing as tp

import torch
from torch import nn

from .core_vq import ResidualVectorQuantization


@dataclass
class QuantizedResult:
    quantized: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.
    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ) -> QuantizedResult:
        """Residual vector quantization on the given input tensor.
        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            layers (list): Layers whose quantized output should be returned. Default: None.
        Returns:
            QuantizedResult:
                The quantized (or approximately quantized) representation with
                the associated number of quantizers and the requested per-layer
                quantized outputs.
        """
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            raise ValueError(
                f"Last layer index in layers: A {max(layers)}. "
                f"Number of quantizers in RVQ: B {self.n_q}. A must be less than B."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
        return quantized, codes, torch.mean(commit_loss), quantized_list

    def encode(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        st: tp.Optional[int] = None,
    ) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            st (int): Layer to start encoding from. Default: 0.
        """
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Layer to start decoding from. Default: 0.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized
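`ResidualVectorQuantizer` wraps the core RVQ with `n_q`/`st` controls, so a subset of quantizer layers can be used at encode or decode time. A minimal sketch (import path illustrative; kmeans_init=False keeps the initialization data-independent):

import torch
from modules.quantization.vq import ResidualVectorQuantizer  # illustrative

rvq = ResidualVectorQuantizer(dimension=64, n_q=4, bins=256, kmeans_init=False)
rvq.eval()
x = torch.randn(2, 64, 50)      # (batch, dimension, time)
codes = rvq.encode(x, n_q=2)    # only the first two quantizer layers
print(codes.shape)              # torch.Size([2, 2, 50]) -- (n_q, batch, time)
y = rvq.decode(codes)
print(y.shape)                  # torch.Size([2, 64, 50])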
indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py (new file, mode 100644)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This source file is copied from https://github.com/facebookresearch/encodec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Encodec SEANet-based encoder and decoder implementation."""
import
typing
as
tp
import
numpy
as
np
import
torch.nn
as
nn
import
torch
from
.
import
SConv1d
,
SConvTranspose1d
,
SLSTM
@
torch
.
jit
.
script
def
snake
(
x
,
alpha
):
shape
=
x
.
shape
x
=
x
.
reshape
(
shape
[
0
],
shape
[
1
],
-
1
)
x
=
x
+
(
alpha
+
1e-9
).
reciprocal
()
*
torch
.
sin
(
alpha
*
x
).
pow
(
2
)
x
=
x
.
reshape
(
shape
)
return
x
class
Snake1d
(
nn
.
Module
):
def
__init__
(
self
,
channels
):
super
().
__init__
()
self
.
alpha
=
nn
.
Parameter
(
torch
.
ones
(
1
,
channels
,
1
))
def
forward
(
self
,
x
):
return
snake
(
x
,
self
.
alpha
)
class
SEANetResnetBlock
(
nn
.
Module
):
"""Residual block from SEANet model.
Args:
dim (int): Dimension of the input/output
kernel_sizes (list): List of kernel sizes for the convolutions.
dilations (list): List of dilations for the convolutions.
activation (str): Activation function.
activation_params (dict): Parameters to provide to the activation function
norm (str): Normalization method.
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
causal (bool): Whether to use fully causal convolution.
pad_mode (str): Padding mode for the convolutions.
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
"""
def
__init__
(
self
,
dim
:
int
,
kernel_sizes
:
tp
.
List
[
int
]
=
[
3
,
1
],
dilations
:
tp
.
List
[
int
]
=
[
1
,
1
],
activation
:
str
=
"ELU"
,
activation_params
:
dict
=
{
"alpha"
:
1.0
},
norm
:
str
=
"weight_norm"
,
norm_params
:
tp
.
Dict
[
str
,
tp
.
Any
]
=
{},
causal
:
bool
=
False
,
pad_mode
:
str
=
"reflect"
,
compress
:
int
=
2
,
true_skip
:
bool
=
True
,
):
super
().
__init__
()
assert
len
(
kernel_sizes
)
==
len
(
dilations
),
"Number of kernel sizes should match number of dilations"
act
=
getattr
(
nn
,
activation
)
if
activation
!=
"Snake"
else
Snake1d
hidden
=
dim
//
compress
block
=
[]
for
i
,
(
kernel_size
,
dilation
)
in
enumerate
(
zip
(
kernel_sizes
,
dilations
)):
in_chs
=
dim
if
i
==
0
else
hidden
out_chs
=
dim
if
i
==
len
(
kernel_sizes
)
-
1
else
hidden
block
+=
[
act
(
**
activation_params
)
if
activation
!=
"Snake"
else
act
(
in_chs
),
SConv1d
(
in_chs
,
out_chs
,
kernel_size
=
kernel_size
,
dilation
=
dilation
,
norm
=
norm
,
norm_kwargs
=
norm_params
,
causal
=
causal
,
pad_mode
=
pad_mode
,
),
]
self
.
block
=
nn
.
Sequential
(
*
block
)
self
.
shortcut
:
nn
.
Module
if
true_skip
:
self
.
shortcut
=
nn
.
Identity
()
else
:
self
.
shortcut
=
SConv1d
(
dim
,
dim
,
kernel_size
=
1
,
norm
=
norm
,
norm_kwargs
=
norm_params
,
causal
=
causal
,
pad_mode
=
pad_mode
,
)
def
forward
(
self
,
x
):
return
self
.
shortcut
(
x
)
+
self
.
block
(
x
)
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): Kernel size and stride ratios. The encoder uses downsampling ratios
            instead of upsampling ratios; hence it will use the ratios in the reverse order to the
            ones specified here, which must match the decoder order.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "weight_norm",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)  # product of all strides
        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        mult = 1
        model: tp.List[nn.Module] = [
            SConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        norm=norm,
                        norm_params=norm_params,
                        activation=activation,
                        activation_params=activation_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]
            # Add downsampling layers
            model += [
                (
                    act(**activation_params)
                    if activation != "Snake"
                    else act(mult * n_filters)
                ),
                SConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2

        if lstm:
            model += [SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)]

        mult = mult * 2 if bidirectional else mult
        model += [
            (
                act(**activation_params)
                if activation != "Snake"
                else act(mult * n_filters)
            ),
            SConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
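
# With the default ratios=[8, 5, 4, 2], the encoder downsamples by
# hop_length = 8 * 5 * 4 * 2 = 320 samples per latent frame, so one second of
# 24 kHz audio yields 24000 / 320 = 75 frames -- exactly what test() below asserts.
# Quick check (a hypothetical _demo_* helper, not part of the original file):
def _demo_encoder_hop_length():
    ratios = [8, 5, 4, 2]
    hop_length = int(np.prod(ratios))  # 320
    return 24000 // hop_length  # 75 latent frames per second
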
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): Kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the start of the decoder.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 1,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "weight_norm",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = False,
        compress: int = 2,
        lstm: int = 2,
        trim_right_ratio: float = 1.0,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation) if activation != "Snake" else Snake1d
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            SConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        if lstm:
            model += [SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add upsampling layers
            model += [
                (
                    act(**activation_params)
                    if activation != "Snake"
                    else act(mult * n_filters)
                ),
                SConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]
            mult //= 2

        # Add final layers
        model += [
            act(**activation_params) if activation != "Snake" else act(n_filters),
            SConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y
def test():
    import torch

    encoder = SEANetEncoder()
    decoder = SEANetDecoder()
    x = torch.randn(1, 1, 24000)
    z = encoder(x)
    print("z ", z.shape)
    assert list(z.shape) == [1, 128, 75], z.shape
    y = decoder(z)
    assert y.shape == x.shape, (x.shape, y.shape)


if __name__ == "__main__":
    test()
indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantize(nn.Module):
    """Vector quantization w/ exponential moving averages (EMA)"""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        decay=0.8,
        commitment=1.0,
        eps=1e-5,
        n_embed=None,
    ):
        super().__init__()
        n_embed = self.default(n_embed, codebook_size)

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.commitment = commitment

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    @property
    def codebook(self):
        return self.embed.transpose(0, 1)

    def exists(self, val):
        return val is not None

    def default(self, val, d):
        return val if self.exists(val) else d

    def ema_inplace(self, moving_avg, new, decay):
        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))

    def laplace_smoothing(self, x, n_categories, eps=1e-5):
        return (x + eps) / (x.sum() + n_categories * eps)

    def forward(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))

        if self.training:
            self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.ema_inplace(self.embed_avg, embed_sum, self.decay)
            cluster_size = (
                self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        loss = F.mse_loss(quantize.detach(), input) * self.commitment
        quantize = input + (quantize - input).detach()

        avg_probs = torch.mean(embed_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantize, loss, perplexity

    def forward_index(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
        quantize = input + (quantize - input).detach()
        return quantize, embed_ind
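
# Shape sketch for VectorQuantize (a hypothetical _demo_* helper, not part of the
# original file; the 64/512 sizes are arbitrary): the straight-through output
# keeps the input shape, and forward_index drops the feature dimension. The EMA
# codebook update in forward() only runs when the module is in training mode.
def _demo_vector_quantize():
    vq = VectorQuantize(dim=64, codebook_size=512).eval()
    x = torch.randn(2, 100, 64)          # (B, T, dim)
    quantized, loss, perplexity = vq(x)  # quantized: (2, 100, 64), scalar loss
    codes, inds = vq.forward_index(x)    # inds: (2, 100) codebook indices
    return quantized.shape, inds.shape
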
class ResidualVQ(nn.Module):
    """Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantize(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x):
        quantized_out = 0.0
        residual = x
        all_losses = []
        all_perplexities = []
        for layer in self.layers:
            quantized, loss, perplexity = layer(residual)
            # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
            # We found that considering only the 1st layer VQ's gradient results in better performance
            # residual = residual - quantized.detach() # considering all layers' gradients
            residual = residual - quantized  # considering only the first layer's gradient
            quantized_out = quantized_out + quantized
            all_losses.append(loss)
            all_perplexities.append(perplexity)
        all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
        return quantized_out, all_losses, all_perplexities

    def forward_index(self, x, flatten_idx=False):
        """
        all_indices: [num_of_quantizers, B, T]
        """
        quantized_out = 0.0
        residual = x
        all_indices = []
        for i, layer in enumerate(self.layers):
            quantized, indices = layer.forward_index(residual)
            # residual = residual - quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            if flatten_idx:
                indices += self.codebook_size * i
            all_indices.append(indices)
        all_indices = torch.stack(all_indices)
        return quantized_out, all_indices

    def initial(self):
        self.codebook = []
        for layer in self.layers:
            self.codebook.append(layer.codebook)
        self.codebook_size = self.codebook[0].size(0)
        self.codebook = torch.stack(self.codebook)
        self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))

    def lookup(self, indices):
        quantized_out = F.embedding(indices, self.codebook)  # Num x T x C
        return torch.sum(quantized_out, dim=0, keepdim=True)


class Quantizer(nn.Module):
    def __init__(
        self,
        code_dim: int,
        codebook_num: int,
        codebook_size: int,
    ):
        super().__init__()
        self.codebook = ResidualVQ(
            dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size
        )

    def initial(self):
        self.codebook.initial()

    def forward(self, z):
        zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
        zq = zq.transpose(2, 1)
        return zq, vqloss, perplexity

    def inference(self, z):
        zq, indices = self.codebook.forward_index(z.transpose(2, 1))
        zq = zq.transpose(2, 1)
        return zq, indices

    def encode(self, z):
        zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
        return zq, indices

    def decode(self, indices):
        z = self.codebook.lookup(indices)
        return z
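
# Residual VQ approximates z as a sum of per-layer codewords: layer i quantizes
# the residual left by layers 0..i-1. Sketch below (a hypothetical _demo_* helper,
# arbitrary sizes); note initial() must run before encode()/decode(), since it is
# what stacks the per-layer codebooks and sets codebook_size for index flattening.
def _demo_residual_vq():
    quantizer = Quantizer(code_dim=32, codebook_num=4, codebook_size=256)
    z = torch.randn(1, 32, 50)             # (B, C, T), the layout forward() expects
    zq, vqloss, perplexity = quantizer(z)  # zq: (1, 32, 50); vqloss: (4,), one per layer
    quantizer.initial()
    _, flat_idx = quantizer.encode(z)      # (4, 1, 50); layer i's indices offset by i * 256
    return zq.shape, flat_idx.shape
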
class Conv1d1x1(nn.Conv1d):
    """1x1 Conv1d."""

    def __init__(self, in_channels, out_channels, bias=True):
        super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)


class Conv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        x = self.conv(x)
        return x


class ConvTranspose1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.deconv(x)
        return x
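
# The defaults above make lengths predictable: Conv1d falls back to "same"-style
# padding (kernel_size - 1) // 2 * dilation, and ConvTranspose1d falls back to
# padding = (stride + 1) // 2 with output_padding = stride % 2, so with
# kernel_size = 2 * stride (as DecoderBlock uses) the output length is exactly
# stride * T. Quick check (a hypothetical _demo_* helper, not part of the original):
def _demo_exact_upsampling():
    for stride in (2, 3, 4, 5):
        up = ConvTranspose1d(8, 8, kernel_size=2 * stride, stride=stride)
        x = torch.randn(1, 8, 17)
        assert up(x).shape[-1] == 17 * stride  # exact integer upsampling
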
class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        dilation=1,
        bias=False,
        nonlinear_activation="ELU",
        nonlinear_activation_params={},
    ):
        super().__init__()
        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class Projector(nn.Module):
    def __init__(self, input_channels: int, code_dim: int, kernel_size=3, stride=1, bias=False):
        super().__init__()
        self.project = Conv1d(
            input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias
        )

    def forward(self, x):
        return self.project(x)


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.res_units = torch.nn.ModuleList()
        for dilation in dilations:
            self.res_units += [
                ResidualUnit(
                    in_channels,
                    in_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
            ]
        self.num_res = len(self.res_units)
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3 if stride == 1 else (2 * stride)),  # special case: stride=1, do not use kernel=2
            stride=stride,
            bias=bias,
        )

    def forward(self, x):
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels,
            out_channels=encode_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )
        self.conv_blocks = torch.nn.ModuleList()
        in_channels = encode_channels
        for idx, stride in enumerate(strides):
            out_channels = int(encode_channels * channel_ratios[idx])  # could be float
            self.conv_blocks += [
                EncoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            ]
            in_channels = out_channels
        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        return x


class DecoderBlock(nn.Module):
    """Decoder block (no up-sampling)"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
    ):
        super().__init__()

        if stride == 1:
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
                stride=stride,
                bias=bias,
            )
        else:
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        self.res_units = torch.nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            self.res_units += [
                ResidualUnit(
                    out_channels,
                    out_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
            ]
        self.num_res = len(self.res_units)

    def forward(self, x):
        x = self.conv(x)
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        code_dim: int,
        output_channels: int,
        decode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )

        self.conv_blocks = torch.nn.ModuleList()
        for idx, stride in enumerate(strides):
            in_channels = int(decode_channels * channel_ratios[idx])
            if idx < (len(channel_ratios) - 1):
                out_channels = int(decode_channels * channel_ratios[idx + 1])
            else:
                out_channels = decode_channels
            self.conv_blocks += [
                DecoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            ]
        self.num_blocks = len(self.conv_blocks)

        self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)

    def forward(self, z):
        x = self.conv1(z)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        x = self.conv2(x)
        return x


class VevoRepCodec(nn.Module):
    def __init__(
        self,
        input_channels=768,
        output_channels=768,
        encode_channels=768,
        decode_channels=768,
        code_dim=768,
        codebook_num=1,
        codebook_size=1024,
        bias=True,
        enc_ratios=(1, 1),
        dec_ratios=(1, 1),
        enc_strides=(1, 1),
        dec_strides=(1, 1),
        enc_kernel_size=3,
        dec_kernel_size=3,
        enc_block_dilations=(1, 1),
        enc_block_kernel_size=3,
        dec_block_dilations=(1, 1),
        dec_block_kernel_size=3,
    ):
        super().__init__()

        self.input_channels = input_channels

        self.encoder = Encoder(
            input_channels=input_channels,
            encode_channels=encode_channels,
            channel_ratios=enc_ratios,
            strides=enc_strides,
            kernel_size=enc_kernel_size,
            bias=bias,
            block_dilations=enc_block_dilations,
            unit_kernel_size=enc_block_kernel_size,
        )
        self.decoder = Decoder(
            code_dim=code_dim,
            output_channels=output_channels,
            decode_channels=decode_channels,
            channel_ratios=dec_ratios,
            strides=dec_strides,
            kernel_size=dec_kernel_size,
            bias=bias,
            block_dilations=dec_block_dilations,
            unit_kernel_size=dec_block_kernel_size,
        )
        self.projector = Projector(
            input_channels=self.encoder.out_channels,
            code_dim=code_dim,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.quantizer = Quantizer(
            code_dim=code_dim, codebook_num=codebook_num, codebook_size=codebook_size
        )

    def forward(self, x):
        x = self.encoder(x)
        z = self.projector(x)
        zq, vqloss, perplexity = self.quantizer(z)
        y = self.decoder(zq)
        return y, zq, z, vqloss, perplexity
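
# End to end, VevoRepCodec is an autoencoder over frame-level SSL feature
# sequences rather than waveforms: encode, project to code_dim, residual-quantize,
# decode. Shape sketch under the default 768-d config with unit strides
# (a hypothetical _demo_* helper, not part of the original file):
def _demo_vevo_repcodec():
    codec = VevoRepCodec()           # defaults: 768-d features, 1 codebook of 1024
    feats = torch.randn(2, 768, 50)  # (B, C, T)
    recon, zq, z, vqloss, perplexity = codec(feats)
    return recon.shape, zq.shape     # both (2, 768, 50)
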
indextts/utils/maskgct/models/tts/maskgct/__pycache__/llama_nar.cpython-310.pyc
0 → 100644
File added
indextts/utils/maskgct/models/tts/maskgct/__pycache__/maskgct_s2a.cpython-310.pyc
0 → 100644
File added
indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt
0 → 100644
File added
indextts/utils/maskgct/models/tts/maskgct/llama_nar.py
0 → 100644
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast


# sinusoidal positional encoding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] * 1.0
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
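
# This is the standard transformer-style embedding: for a scalar step t it returns
# [sin(t*w_0), ..., sin(t*w_{d/2-1}), cos(t*w_0), ..., cos(t*w_{d/2-1})] with
# w_k = 10000^(-k / (d/2 - 1)); here it embeds the continuous diffusion timestep.
# Quick shape check (a hypothetical _demo_* helper, not part of the original file):
def _demo_sinusoidal_pos_emb():
    pos_emb = SinusoidalPosEmb(dim=1024)
    t = torch.rand(4)        # one diffusion step per batch element, in [0, 1]
    return pos_emb(t).shape  # torch.Size([4, 1024]), later fed to diff_step_mlp
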
class LlamaAdaptiveRMSNorm(nn.Module):
    def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
        super().__init__()
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        nn.init.zeros_(self.to_weight.weight)
        nn.init.ones_(self.to_weight.bias)
        self.variance_epsilon = eps
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, hidden_states, cond_embedding):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.to_weight(cond_embedding)
        if len(weight.shape) == 2:
            weight = weight.unsqueeze(1)

        return (weight * hidden_states).to(input_dtype)
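
# Unlike vanilla RMSNorm's fixed per-channel gain, the gain here is predicted from
# a conditioning vector (the diffusion-step embedding), and the zeros/ones init of
# to_weight means it starts out exactly equal to plain RMSNorm. Minimal sketch
# (a hypothetical _demo_* helper, not part of the original file):
def _demo_adaptive_rmsnorm():
    norm = LlamaAdaptiveRMSNorm(hidden_size=1024, dim_cond=1024)
    h = torch.randn(2, 50, 1024)  # (B, T, hidden)
    cond = torch.randn(2, 1024)   # (B, dim_cond), broadcast over T via unsqueeze(1)
    return norm(h, cond_embedding=cond).shape  # torch.Size([2, 50, 1024])
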
class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        self.layer_idx = layer_idx
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )

    # add `cond` in forward function
    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states, cond_embedding=cond_embedding)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states, cond_embedding=cond_embedding)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
class DiffLlama(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        # self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0)

        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.post_init()
        # self.reset_parameters()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
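
    # The key departure from LlamaModel's causal mask: no triangular mask is added;
    # the (bsz, seq_len) padding mask is broadcast to (bsz, 1, tgt, src) and
    # inverted into additive -inf terms, so attention is bidirectional over all
    # non-padded positions. Sketch (a hypothetical _demo_* method, not in the
    # original class) exercising the method above:
    def _demo_noncausal_mask(self):
        pad_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding
        x = torch.zeros(1, 4, 1)                         # stand-in for inputs_embeds
        additive = self._prepare_decoder_attention_mask(pad_mask, (1, 4), x, 0)
        return additive[0, 0]  # column 3 is -inf on every query row; no causal triangle
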
    def forward(
        self,
        x,
        diffusion_step,
        cond,
        x_mask,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # retrieve some shape info
        batch_size, seq_length, _ = x.shape

        # condition mlp
        cond_embedding = self.cond_mlp(cond)  # (B, T, C)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)

        x = x + cond_embedding
        inputs_embeds = x
        attention_mask = x_mask

        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states
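
# DiffLlama is thus used purely as a conditional denoiser: continuous features x
# plus a frame-aligned condition go in, the diffusion-step embedding modulates
# every adaptive RMSNorm, and hidden states come back out (no LM head, and the
# KV-cache path is unused in practice). Call sketch with made-up sizes
# (a hypothetical _demo_* helper; assumes the transformers version this repo pins):
def _demo_diff_llama():
    model = DiffLlama(hidden_size=1024, num_heads=16, num_layers=2)  # tiny depth
    x = torch.randn(2, 50, 1024)     # noisy / partially masked features (B, T, C)
    cond = torch.randn(2, 50, 1024)  # frame-aligned conditioning
    t = torch.rand(2)                # one diffusion step per batch item
    x_mask = torch.ones(2, 50)       # 0 marks padding
    return model(x, t, cond, x_mask).shape  # torch.Size([2, 50, 1024])
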
class DiffLlamaPrefix(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.embed_tokens = None

        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        x,
        diffusion_step,
        x_mask,
        phone_embedding: Optional[torch.LongTensor] = None,
        phone_mask: Optional[torch.FloatTensor] = None,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        phone_embedding = self.cond_mlp(phone_embedding)  # (B, T, C)
        phone_length = phone_embedding.shape[1]
        inputs_embeds = torch.cat([phone_embedding, x], dim=1)
        attention_mask = torch.cat([phone_mask, x_mask], dim=1)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)

        # retrieve some shape info
        batch_size, seq_length, _ = inputs_embeds.shape

        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states[:, phone_length:]
indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py
0 → 100644
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import torch.nn as nn
import math
from einops import rearrange

from indextts.utils.maskgct.models.tts.maskgct.llama_nar import DiffLlama


def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = logits.topk(k, dim=-1)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(2, ind, val)
    return probs


def log(t, eps=1e-10):
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
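
# gumbel_sample is the Gumbel-max trick: adding i.i.d. Gumbel noise -log(-log(u))
# to temperature-scaled logits and taking the argmax draws exact samples from the
# softmax distribution, while top_k first floors all but the best (1 - thres)
# fraction of logits to -inf. Sanity sketch (a hypothetical _demo_* helper, not
# part of the original file):
def _demo_gumbel_top_k():
    logits = torch.randn(1, 4, 1024)      # (B, T, codebook_size)
    filtered = top_k(logits, thres=0.98)  # keep ~2% best entries per position
    ids = gumbel_sample(filtered, temperature=1.0)
    return ids.shape                      # torch.Size([1, 4]) sampled code indices
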
class MaskGCT_S2A(nn.Module):
    def __init__(
        self,
        num_quantizer=12,
        hidden_size=1024,
        num_layers=16,
        num_heads=16,
        codebook_size=1024,
        cfg_scale=0.15,
        mask_layer_schedule="linear",
        cond_codebook_size=1024,
        cond_dim=1024,
        predict_layer_1=True,
        cfg=None,
    ):
        super().__init__()

        num_quantizer = (
            cfg.num_quantizer
            if cfg is not None and hasattr(cfg, "num_quantizer")
            else num_quantizer
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        num_layers = (
            cfg.num_layers
            if cfg is not None and hasattr(cfg, "num_layers")
            else num_layers
        )
        num_heads = (
            cfg.num_heads if cfg is not None and hasattr(cfg, "num_heads") else num_heads
        )
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        cfg_scale = (
            cfg.cfg_scale if cfg is not None and hasattr(cfg, "cfg_scale") else cfg_scale
        )
        mask_layer_schedule = (
            cfg.mask_layer_schedule
            if cfg is not None and hasattr(cfg, "mask_layer_schedule")
            else mask_layer_schedule
        )
        cond_codebook_size = (
            cfg.cond_codebook_size
            if cfg is not None and hasattr(cfg, "cond_codebook_size")
            else cond_codebook_size
        )
        cond_dim = (
            cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
        )
        predict_layer_1 = (
            cfg.predict_layer_1
            if cfg is not None and hasattr(cfg, "predict_layer_1")
            else predict_layer_1
        )

        self.num_quantizer = num_quantizer
        self.hidden_size = hidden_size
        self.codebook_size = codebook_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.cfg_scale = cfg_scale
        self.mask_layer_schedule = mask_layer_schedule
        self.cond_codebook_size = cond_codebook_size
        self.cond_dim = cond_dim
        self.predict_layer_1 = predict_layer_1

        self.layer_emb = nn.Embedding(self.num_quantizer, self.hidden_size)
        self.mask_emb = nn.Embedding(1, self.hidden_size)

        self.token_emb = torch.nn.ModuleList(
            [nn.Embedding(self.codebook_size, self.hidden_size) for _ in range(self.num_quantizer)]
        )

        self.to_logits = torch.nn.ModuleList(
            [nn.Linear(self.hidden_size, self.codebook_size) for _ in range(self.num_quantizer)]
        )

        self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)

        self.reset_parameters()

        self.diff_estimator = DiffLlama(
            hidden_size=hidden_size,
            num_heads=self.num_heads,
            num_layers=num_layers,
        )

    def mask_prob(self, t):
        return torch.sin(t * np.pi / 2).to(t.device)

    def mask_layer(self, t):
        # print(self.predict_layer_1)
        if self.mask_layer_schedule == "uniform":
            if self.predict_layer_1:
                mask_layer = torch.randint(0, self.num_quantizer, (1,)).to(t.device)
            else:
                mask_layer = torch.randint(1, self.num_quantizer, (1,)).to(t.device)
        elif self.mask_layer_schedule == "cosine":
            if self.predict_layer_1:
                weights = torch.tensor(
                    [np.cos(i / self.num_quantizer * np.pi / 2) for i in range(self.num_quantizer)]
                )
            else:
                weights = torch.tensor(
                    [0]
                    + [
                        np.cos((i - 1) / self.num_quantizer * np.pi / 2)
                        for i in range(1, self.num_quantizer)
                    ]
                )
            mask_layer = torch.multinomial(weights, 1).to(t.device)
        elif self.mask_layer_schedule == "linear":
            if self.predict_layer_1:
                weights = torch.tensor(
                    [self.num_quantizer - i for i in range(self.num_quantizer)]
                )
            else:
                weights = torch.tensor(
                    [0]
                    + [self.num_quantizer - (i - 1) for i in range(1, self.num_quantizer)]
                )
            weights = weights / weights.sum()
            mask_layer = torch.multinomial(weights, 1).to(t.device)
        # print(mask_layer)

        new_t = t
        return mask_layer, new_t

    def forward_diffusion(self, x0, t):
        # x0: (B, T, num_quantizer)
        mask_layer, new_t = self.mask_layer(t)  # (1,)
        mask_prob = self.mask_prob(new_t)  # (B,)
        mask_token = self.mask_emb(torch.zeros_like(mask_layer))  # (1, hidden_size)

        xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)

        cfg_scale = self.cfg_scale

        # get prompt len
        if torch.rand(1) > cfg_scale:
            prompt_len = torch.randint(
                min(x0.shape[1] // 4, 5), x0.shape[1] // 2, (x0.shape[0],)
            ).to(x0.device)  # (B,)
        else:
            prompt_len = torch.zeros(x0.shape[0]).to(x0)  # (B,)

        # get is prompt
        is_prompt = torch.zeros_like(x0[:, :, 0])  # (B, T)
        col_indices = (
            torch.arange(is_prompt.shape[1]).repeat(is_prompt.shape[0], 1).to(prompt_len)
        )  # (B, T)
        is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1  # (B, T) 1 if prompt

        for idx, token_emb_idx in enumerate(self.token_emb):
            if idx < mask_layer:
                xt = xt + token_emb_idx(x0[:, :, idx])  # (B, T, hidden_size)

            elif idx == mask_layer:
                mask = torch.bernoulli(
                    torch.ones_like(x0[:, :, idx]) * mask_prob[..., None]
                )  # mask if 1, not mask if 0
                # prompt part don't need to be masked
                mask[is_prompt.bool()] = 0
                # Ensure at least one token is masked
                mask_num = mask.sum(dim=1, keepdim=False)
                all_zero_mask = (mask_num == 0).bool()
                row_indices_to_modify = torch.nonzero(all_zero_mask)
                # mask the first token if all tokens are not masked (may mask pad if random indices)
                mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1

                mask = mask[..., None]  # (B, T, 1)
                xt = (
                    xt
                    + mask * mask_token[:, None, :]
                    + (1 - mask) * token_emb_idx(x0[:, :, idx])
                )  # (B, T, hidden_size)

            else:
                # prompt part don't need to be masked
                xt = (
                    xt
                    + token_emb_idx(x0[:, :, idx]) * is_prompt[..., None]
                    + mask_token * (1 - is_prompt[..., None])
                )

        return xt, new_t, mask_layer, mask, prompt_len, mask_prob
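
    # The masking schedule is mask_prob(t) = sin(pi * t / 2): t near 1 masks almost
    # everything, t near 0 almost nothing. The "linear" layer schedule meanwhile
    # draws lower codebook layers more often (weight num_quantizer - i for layer i).
    # Numeric check (a hypothetical _demo_* method, not part of the original class):
    def _demo_mask_schedule(self):
        t = torch.tensor([0.1, 0.5, 0.9])
        frac = self.mask_prob(t)  # ~[0.156, 0.707, 0.988] fraction of tokens masked
        weights = torch.tensor(
            [float(self.num_quantizer - i) for i in range(self.num_quantizer)]
        )
        return frac, (weights / weights.sum())[0]  # layer 0 prob: 12/78 ~ 0.154
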
    def loss_t(self, x0, x_mask, t, cond=None):
        xt, new_t, mask_layer, mask, prompt_len, mask_prob = self.forward_diffusion(x0, t)
        # xt: (B, T, hidden_size)
        # new_t: (B,)
        # mask_layer: (1,)
        # mask: (B, T, 1) mask if 1, not mask if 0
        # prompt_len: (B,)
        # mask_prob: (B,)

        mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(1)  # (1, 1, hidden_size)
        cond = cond + mask_layer_cond  # (B, T, hidden_size)

        embeds = self.diff_estimator(xt, new_t, cond, x_mask)  # (B, T, hidden_size)
        logits = self.to_logits[mask_layer.item()](embeds)  # (B, T, codebook_size)

        # final mask used for loss calculation
        final_mask = mask * x_mask[..., None]  # (B, T, 1)

        return logits, mask_layer, final_mask, x0, prompt_len, mask_prob

    def compute_loss(self, x0, x_mask, cond=None):
        # x0: (B, T, num_quantizer)
        # x_mask: (B, T) mask is 0 for padding
        t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
        t = torch.clamp(t, 1e-5, 1.0)
        return self.loss_t(x0, x_mask, t, cond)

    def reset_parameters(self):
        def _reset_parameters(m):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)

                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)

            elif (
                isinstance(m, nn.Conv1d)
                or isinstance(m, nn.ConvTranspose1d)
                or isinstance(m, nn.Conv2d)
                or isinstance(m, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)
    @torch.no_grad()
    def reverse_diffusion(
        self,
        cond,
        prompt,
        x_mask=None,
        prompt_mask=None,
        temp=1.5,
        filter_thres=0.98,
        max_layer=None,
        gt_code=None,
        n_timesteps=[10, 4, 4, 4, 4, 4, 4, 4],
        cfg=1.0,
        rescale_cfg=1.0,
    ):
        assert len(n_timesteps) == self.num_quantizer  # each layer has a number of steps

        prompt_code = prompt  # (B, prompt_len, num_quantizer)
        prompt_len = prompt_code.shape[1]
        target_len = cond.shape[1] - prompt_len

        if x_mask is None:
            x_mask = torch.ones(cond.shape[0], target_len).to(cond.device)  # (B, T)
        if prompt_mask is None:
            prompt_mask = torch.ones(cond.shape[0], prompt_len).to(cond.device)  # (B, prompt_len)

        cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
            x_mask.device
        )  # (B, T, hidden_size)

        bsz, seq_len, _ = cum.shape

        choice_temp = 1.0
        start_temp = temp  # temperature for sampling
        start_choice_temp = choice_temp  # temperature for choosing mask tokens

        if max_layer is None:
            max_layer = self.num_quantizer
        xt = torch.LongTensor(bsz, seq_len, max_layer).to(x_mask.device)

        if gt_code is not None:
            gt_layer = gt_code.shape[-1]
            xt[:, :, :gt_layer] = gt_code
            for i in range(gt_layer):
                cum += self.token_emb[i](xt[:, :, i])
        else:
            gt_layer = 0

        for mask_layer in range(gt_layer, max_layer):
            steps = n_timesteps[mask_layer]
            to_logits = self.to_logits[mask_layer]
            token_emb = self.token_emb[mask_layer]
            mask_layer = torch.tensor(mask_layer).to(x_mask.device).long().unsqueeze(0)
            mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(1)  # (1,) -> (1, 1, hidden_size)
            temp_cond = cond + mask_layer_cond  # (B, T, hidden_size)

            mask_token = self.mask_emb(torch.zeros_like(mask_layer))  # (1, hidden_size)
            mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device)  # (B, T, 1)
            seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
            h = 1.0 / steps

            # prompt_code: (B, prompt_len, num_quantizer)
            cur_prompt = 0
            for idx, emb in enumerate(self.token_emb):
                cur_prompt = cur_prompt + emb(prompt_code[:, :, idx])  # (B, prompt_len, hidden_size)

            t_list = [1.0 - i * h for i in range(steps)]
            t_list.append(0.0)
            for i in range(steps):
                t = t_list[i] * torch.ones(bsz).to(x_mask.device)

                token = token_emb(seq)  # (B, T, hidden_size)
                cur = cum + mask * mask_token[:, None, :] + (~mask) * token
                cur = cur + mask_token[:, None, :] * (max_layer - 1 - mask_layer)

                xt_input = torch.cat([cur_prompt, cur], dim=1)  # (B, T, hidden_size)
                xt_mask = torch.cat([prompt_mask, x_mask], dim=1)  # (B, T), mask is 0 for padding

                embeds = self.diff_estimator(xt_input, t, temp_cond, xt_mask)
                embeds = embeds[:, prompt_len:, :]

                # cfg
                if cfg > 0:
                    mask_embeds = self.diff_estimator(cur, t, temp_cond[:, prompt_len:, :], x_mask)
                    pos_emb_std = embeds.std()  # std(g_cond)
                    embeds = embeds + cfg * (embeds - mask_embeds)  # g_cfg
                    rescale_embeds = embeds * pos_emb_std / embeds.std()  # g_final
                    embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds

                logits = to_logits(embeds)  # (B, T, codebook_size)
                annealing_scale = t_list[i]

                choice_temp = start_choice_temp * annealing_scale
                temp = start_temp * annealing_scale
                logits = top_k(logits, filter_thres)

                if i == steps - 1:
                    # greedy
                    if steps == 1:
                        temp = 0.2
                        sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
                    else:
                        sampled_ids = logits.argmax(dim=-1)
                else:
                    # sampling
                    sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))

                seq = torch.where(mask.squeeze(-1), sampled_ids, seq)

                scores = logits.softmax(dim=-1)
                scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
                scores = rearrange(scores, "b n 1 -> b n")
                scores = choice_temp * gumbel_noise(scores) + scores
                scores = 1 - scores

                next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)

                next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()

                if next_mask_num == 0:
                    break
                scores = scores.masked_fill(
                    ~mask.squeeze(-1), -torch.finfo(scores.dtype).max
                )

                mask_indices = scores.topk(next_mask_num, dim=-1).indices
                mask = torch.zeros_like(scores, dtype=torch.bool).scatter(1, mask_indices, True)
                seq = seq.masked_fill(mask, 0)

                mask = mask.unsqueeze(-1)

            cum = cum + token_emb(seq)
            xt[..., mask_layer.squeeze(0).item()] = seq

        return xt
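
    # Generation therefore runs layer by layer: for codebook layer l it iteratively
    # unmasks over n_timesteps[l] steps, keeping the highest-confidence samples and
    # re-masking the rest according to mask_prob of the next t, with classifier-free
    # guidance g = g_cond + cfg * (g_cond - g_uncond) plus std rescaling applied in
    # feature space. Shape-only call sketch (a hypothetical _demo_* method, not in
    # the original class; a real model needs trained weights):
    def _demo_reverse_diffusion(self):
        cond = self.cond_emb(torch.randint(0, self.cond_codebook_size, (1, 80)))
        prompt = torch.randint(0, self.codebook_size, (1, 30, self.num_quantizer))
        codes = self.reverse_diffusion(
            cond=cond,
            prompt=prompt,
            n_timesteps=[10] + [4] * (self.num_quantizer - 1),  # one entry per layer
            cfg=1.0,
            rescale_cfg=1.0,
        )
        return codes.shape  # (1, 50, num_quantizer): codes for the 50 non-prompt frames
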
    def forward(self, x0, x_mask, cond_code=None):
        # x0: (B, T, num_quantizer)
        # x_mask: (B, T) mask is 0 for padding
        # cond_code: semantic token (B, T)
        cond = self.cond_emb(cond_code)
        logits, mask_layer, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
            x0,
            x_mask,
            cond,
        )
        return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
indextts/utils/maskgct_utils.py
0 → 100644
import torch
import librosa
import json5  # required by _load_config below
import os
from huggingface_hub import hf_hub_download
from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel
import safetensors
import numpy as np
from indextts.utils.maskgct.models.codec.kmeans.repcodec_model import RepCodec
from indextts.utils.maskgct.models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
from indextts.utils.maskgct.models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
import time
def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether changing keys to lower case. Defaults to False.

    Returns:
        dict: dictionary that stores configurations
    """
    with open(config_fn, "r") as f:
        data = f.read()
    config_ = json5.loads(data)
    if "base_config" in config_:
        # load configurations from new path
        p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
        p_config_ = _load_config(p_config_path)
        config_ = override_config(p_config_, config_)
    if lowercase:
        # change keys in config_ to lower case
        config_ = get_lowercase_keys_config(config_)
    return config_


def load_config(config_fn, lowercase=False):
    """Load configurations into a JsonHParams object

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether changing keys to lower case. Defaults to False.

    Returns:
        JsonHParams: an object that stores configurations
    """
    config_ = _load_config(config_fn, lowercase=lowercase)
    # create a JsonHParams object with the configuration dict
    cfg = JsonHParams(**config_)
    return cfg


class JsonHParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
def
build_semantic_model
(
path_
=
'./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'
,
bert_path
=
None
):
semantic_model
=
Wav2Vec2BertModel
.
from_pretrained
(
# "facebook/w2v-bert-2.0"
bert_path
)
semantic_model
.
eval
()
stat_mean_var
=
torch
.
load
(
path_
)
semantic_mean
=
stat_mean_var
[
"mean"
]
semantic_std
=
torch
.
sqrt
(
stat_mean_var
[
"var"
])
return
semantic_model
,
semantic_mean
,
semantic_std
def build_semantic_codec(cfg):
    semantic_codec = RepCodec(cfg=cfg)
    semantic_codec.eval()
    return semantic_codec


def build_s2a_model(cfg, device):
    soundstorm_model = MaskGCT_S2A(cfg=cfg)
    soundstorm_model.eval()
    soundstorm_model.to(device)
    return soundstorm_model


def build_acoustic_codec(cfg, device):
    codec_encoder = CodecEncoder(cfg=cfg.encoder)
    codec_decoder = CodecDecoder(cfg=cfg.decoder)
    codec_encoder.eval()
    codec_decoder.eval()
    codec_encoder.to(device)
    codec_decoder.to(device)
    return codec_encoder, codec_decoder
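# A sketch of how the builders above might be used; the config file, its field
# names, and the checkpoint layout are assumptions, not fixed by this file:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = load_config("maskgct.json5")  # hypothetical config file
semantic_model, semantic_mean, semantic_std = build_semantic_model(
    bert_path="facebook/w2v-bert-2.0"
)
semantic_codec = build_semantic_codec(cfg.model.semantic_codec).to(device)
codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)
s2a_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
s2a_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)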
class Inference_Pipeline:
    def __init__(
        self,
        semantic_model,
        semantic_codec,
        semantic_mean,
        semantic_std,
        codec_encoder,
        codec_decoder,
        s2a_model_1layer,
        s2a_model_full,
    ):
        self.semantic_model = semantic_model
        self.semantic_codec = semantic_codec
        self.semantic_mean = semantic_mean
        self.semantic_std = semantic_std
        self.codec_encoder = codec_encoder
        self.codec_decoder = codec_decoder
        self.s2a_model_1layer = s2a_model_1layer
        self.s2a_model_full = s2a_model_full
    @torch.no_grad()
    def get_emb(self, input_features, attention_mask):
        vq_emb = self.semantic_model(
            input_features=input_features,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # The 17th hidden state is taken as the semantic feature and normalized
        # with the precomputed mean/std statistics.
        feat = vq_emb.hidden_states[17]  # (B, T, C)
        feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)
        return feat
    @torch.no_grad()
    def extract_acoustic_code(self, speech):
        vq_emb = self.codec_encoder(speech.unsqueeze(1))  # (B, 1, T) waveform in
        _, vq, _, _, _ = self.codec_decoder.quantizer(vq_emb)
        acoustic_code = vq.permute(1, 2, 0)  # -> (B, T, num_quantizer)
        return acoustic_code

    @torch.no_grad()
    def get_scode(self, inputs):
        semantic_code, feat = self.semantic_codec.quantize(inputs)
        # vq = self.semantic_codec.quantizer.vq2emb(semantic_code.unsqueeze(1))
        # vq = vq.transpose(1, 2)
        return semantic_code
    @torch.no_grad()
    def semantic2acoustic(
        self,
        combine_semantic_code,
        acoustic_code,
        n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        cfg=2.5,
        rescale_cfg=0.75,
    ):
        semantic_code = combine_semantic_code

        # Stage 1: predict only the first (coarsest) codebook layer.
        cond = self.s2a_model_1layer.cond_emb(semantic_code)
        prompt = acoustic_code[:, :, :]
        predict_1layer = self.s2a_model_1layer.reverse_diffusion(
            cond=cond,
            prompt=prompt,
            temp=1.5,
            filter_thres=0.98,
            n_timesteps=n_timesteps[:1],
            cfg=cfg,
            rescale_cfg=rescale_cfg,
        )

        # Stage 2: predict all 12 layers, conditioning on the stage-1 codes.
        cond = self.s2a_model_full.cond_emb(semantic_code)
        prompt = acoustic_code[:, :, :]
        predict_full = self.s2a_model_full.reverse_diffusion(
            cond=cond,
            prompt=prompt,
            temp=1.5,
            filter_thres=0.98,
            n_timesteps=n_timesteps,
            cfg=cfg,
            rescale_cfg=rescale_cfg,
            gt_code=predict_1layer,
        )

        # Decode the predicted codes and the prompt codes back to waveforms.
        vq_emb = self.codec_decoder.vq2emb(predict_full.permute(2, 0, 1), n_quantizers=12)
        recovered_audio = self.codec_decoder(vq_emb)
        prompt_vq_emb = self.codec_decoder.vq2emb(prompt.permute(2, 0, 1), n_quantizers=12)
        recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
        recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
        recovered_audio = recovered_audio[0][0].cpu().numpy()
        combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
        return combine_audio, recovered_audio
    def s2a_inference(
        self,
        prompt_speech_path,
        combine_semantic_code,
        cfg=2.5,  # unused; cfg_s2a below is what is passed through
        n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        cfg_s2a=2.5,
        rescale_cfg_s2a=0.75,
    ):
        speech = librosa.load(prompt_speech_path, sr=24000)[0]
        acoustic_code = self.extract_acoustic_code(
            torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device)
        )
        _, recovered_audio = self.semantic2acoustic(
            combine_semantic_code,
            acoustic_code,
            n_timesteps=n_timesteps_s2a,
            cfg=cfg_s2a,
            rescale_cfg=rescale_cfg_s2a,
        )
        return recovered_audio
    @torch.no_grad()
    def gt_inference(
        self,
        prompt_speech_path,
        combine_semantic_code,
    ):
        speech = librosa.load(prompt_speech_path, sr=24000)[0]
        '''
        acoustic_code = self.extract_acoustic_code(
            torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device)
        )
        prompt = acoustic_code[:, :, :]
        prompt_vq_emb = self.codec_decoder.vq2emb(
            prompt.permute(2, 0, 1), n_quantizers=12
        )
        '''
        # Reconstruct the prompt directly through encoder -> decoder,
        # skipping the quantizer round-trip commented out above.
        prompt_vq_emb = self.codec_encoder(
            torch.tensor(speech).unsqueeze(0).unsqueeze(1).to(combine_semantic_code.device)
        )
        recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
        recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
        return recovered_prompt_audio
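# End-to-end sketch: wire the components from the sketch above into the pipeline,
# extract semantic tokens from a reference clip, and resynthesize. File names are
# hypothetical; the 16 kHz rate follows the w2v-bert-2.0 convention:
pipeline = Inference_Pipeline(
    semantic_model.to(device), semantic_codec, semantic_mean, semantic_std,
    codec_encoder, codec_decoder, s2a_1layer, s2a_full,
)
processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
wav16k, _ = librosa.load("prompt.wav", sr=16000)
inputs = processor(wav16k, sampling_rate=16000, return_tensors="pt")
feat = pipeline.get_emb(
    inputs["input_features"].to(device), inputs["attention_mask"].to(device)
)
semantic_code = pipeline.get_scode(feat)
audio = pipeline.s2a_inference("prompt.wav", semantic_code)  # 24 kHz numpy array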
indextts/utils/text_utils.py 0 → 100644
import re

from textstat import textstat


def contains_chinese(text):
    # Regex matching Chinese characters plus digits; both are treated as zh.
    if re.search(r'[\u4e00-\u9fff0-9]', text):
        return True
    return False
def get_text_syllable_num(text):
    chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
    number_char_pattern = re.compile(r'[0-9]')
    syllable_num = 0
    tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z]+|[0-9]+', text)
    # print(tokens)
    if contains_chinese(text):
        for token in tokens:
            if chinese_char_pattern.search(token) or number_char_pattern.search(token):
                # Each Chinese character or digit counts as one syllable.
                syllable_num += len(token)
            else:
                syllable_num += textstat.syllable_count(token)
    else:
        syllable_num = textstat.syllable_count(text)
    return syllable_num
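# Examples: Chinese characters and digits count one syllable each, while
# Latin-script tokens go through textstat:
get_text_syllable_num("你好2024")     # -> 6: 2 Chinese chars + 4 digits
get_text_syllable_num("hello world")  # -> 3, via textstat.syllable_count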
def get_text_tts_dur(text):
    """Estimate (max_dur, min_dur) in seconds from the syllable count and
    empirical speaking-rate bounds (syllables per second)."""
    min_speed = 3  # 2.18 #
    max_speed = 5.50
    ratio = 0.8517 if contains_chinese(text) else 1.0
    syllable_num = get_text_syllable_num(text)
    # The slowest rate bounds the longest duration, the fastest the shortest.
    max_dur = syllable_num * ratio / min_speed
    min_dur = syllable_num * ratio / max_speed
    return max_dur, min_dur
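# Worked example (Latin-script syllable counts come from textstat, so treat the
# English numbers as approximate):
max_dur, min_dur = get_text_tts_dur("hello world")  # 3 syllables -> (1.0, ~0.55) s
max_zh, min_zh = get_text_tts_dur("你好世界")        # 4 * 0.8517 -> (~1.14, ~0.62) s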
indextts/utils/typical_sampling.py 0 → 100644
import torch
from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper


class TypicalLogitsWarper(BaseTypicalLogitsWarper):
    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # calculate the entropy of the next-token distribution
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort: tokens whose surprisal is closest to the entropy come first
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens once the cumulative mass exceeds the threshold
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind[last_ind < 0] = 0
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep tokens
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
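# Sketch: plugging the warper into a Hugging Face sampling loop; the vocabulary
# size and the random logits are placeholders for one real decoding step:
from transformers import LogitsProcessorList

warpers = LogitsProcessorList([TypicalLogitsWarper(mass=0.9, min_tokens_to_keep=2)])
next_token_logits = torch.randn(1, 32000)  # fake logits for one decoding step
filtered = warpers(torch.zeros(1, 0, dtype=torch.long), next_token_logits)
probs = torch.softmax(filtered, dim=-1)    # filtered tokens get probability 0
next_token = torch.multinomial(probs, num_samples=1)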
indextts/utils/utils.py 0 → 100644
import os
import re
import random

import torch
import torchaudio

MATPLOTLIB_FLAG = False


def load_audio(audiopath, sampling_rate):
    audio, sr = torchaudio.load(audiopath)
    # print(f"wave shape: {audio.shape}, sample_rate: {sr}")
    if audio.size(0) > 1:
        # keep only the first channel (selects, rather than downmixes, to mono)
        audio = audio[0].unsqueeze(0)
    if sr != sampling_rate:
        try:
            audio = torchaudio.functional.resample(audio, sr, sampling_rate)
        except Exception:
            print(f"Warning: {audiopath}, wave shape: {audio.shape}, sample_rate: {sr}")
            return None
    # clip invalid sample values into [-1, 1]
    audio.clip_(-1, 1)
    return audio
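# Example: load a clip at the 24 kHz model rate; returns a (1, num_samples)
# tensor, or None if resampling fails (the path is hypothetical):
wav = load_audio("prompt.wav", sampling_rate=24000)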
def tokenize_by_CJK_char(line: str) -> str:
    """
    Tokenize a line of text with CJK chars.

    Note: all returned characters will be upper case.

    Example:
        input = "你好世界是 hello world 的中文"
        output = "你 好 世 界 是 HELLO WORLD 的 中 文"

    Args:
        line: The input text.

    Return:
        A new string tokenized by CJK char.
    """
    # The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
    pattern = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
    )
    chars = pattern.split(line.strip().upper())
    return " ".join([w.strip() for w in chars if w.strip()])
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a mask tensor marking the padded part of each sequence.

    See the description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).

    Returns:
        torch.Tensor: Boolean mask tensor; True marks padded positions.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor, clipping the input
    to avoid taking the log of near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))
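# safe_log guards against -inf from zeros, e.g. when taking log-mel spectrograms:
log_mel = safe_log(torch.zeros(2, 3))  # every entry is log(1e-7) ≈ -16.12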