ModelZoo / InspireMusic_pytorch · Commit 0112b0f0
Authored Feb 14, 2025 by chenzk

v1.0

Pipeline #2394 canceled with stages · Changes: 474 · Pipelines: 1
Showing 20 changed files with 1321 additions and 0 deletions.
examples/music_generation/inspiremusic/wavtokenizer/encoder/__init__.py  +12 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/__init__.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/distrib.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/model.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/utils.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/distrib.py  +124 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/model.py  +324 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__init__.py  +22 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/__init__.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/conv.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/lstm.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/norm.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/seanet.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/transformer.cpython-310.pyc  +0 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/conv.py  +253 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/lstm.py  +39 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/norm.py  +28 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/seanet.py  +253 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/transformer.py  +119 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/msstftd.py  +147 −0
examples/music_generation/inspiremusic/wavtokenizer/encoder/__init__.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""EnCodec neural audio codec."""
__version__ = "0.1.2a3"

from .model import EncodecModel
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/__init__.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/distrib.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/model.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/__pycache__/utils.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/distrib.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync the given buffers across workers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
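The helpers above stand in for DistributedDataParallel with explicit collectives. Below is a minimal sketch of how they would typically be wired into a training loop; `model`, `optimizer`, `loader`, and `loss_fn` are placeholder names, and the import path is assumed from this file's location, neither is part of this commit.

import torch
from encoder import distrib  # assumed import path for the module above

# One-time setup: make sure every worker starts from identical weights.
distrib.broadcast_tensors(model.parameters())

for batch in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch))
    loss.backward()
    # Sum, then average, gradients across workers before stepping.
    distrib.sync_grad(model.parameters())
    optimizer.step()
    # Metrics averaged over all workers, weighted by the local batch size.
    metrics = distrib.average_metrics({'loss': loss.item()}, count=len(batch))

Every helper is a no-op when torch.distributed is not initialized, so the same loop runs unchanged on a single GPU.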
examples/music_generation/inspiremusic/wavtokenizer/encoder/model.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""EnCodec model implementation."""
import math
from pathlib import Path
import typing as tp

import numpy as np
import torch
from torch import nn

from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url


ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'

EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]


class LMModel(nn.Module):
    """Language Model to estimate probabilities of each codebook entry.
    We predict all codebooks in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
    """
    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card = card
        self.n_q = n_q
        self.dim = dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """
        Args:
            indices (torch.Tensor): indices from the previous time step. Indices
                should be 1 + actual index in the codebook. The value 0 is reserved for
                when the index is missing (i.e. first time step). Shape should be
                `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
        with a shape `[B, card, n_q, T]`.
        """
        B, K, T = indices.shape
        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
        out, states, offset = self.transformer(input_, states, offset)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
        return torch.softmax(logits, dim=1), states, offset


class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.

    Args:
        target_bandwidths (list of float): Target bandwidths.
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segment, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """
    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."

    @property
    def segment_length(self) -> tp.Optional[int]:
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames

    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        length = x.shape[-1]
        duration = length / self.sample_rate
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale

    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.
        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)

    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        codes, scale = encoded_frame
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frames = self.encode(x)
        return self.decode(frames)[:, :, :x.shape[-1]]

    def set_target_bandwidth(self, bandwidth: float):
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth

    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate.
        """
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            raise RuntimeError("No LM pre-trained for the current Encodec model.")
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm

    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model

    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            return torch.load(file)
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type:ignore

    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model.
        """
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6, 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model.
        """
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model


def test():
    from itertools import product
    import torchaudio
    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz,
    }
    for model_name, bw in product(models.keys(), bandwidths):
        model = models[model_name]()
        model.set_target_bandwidth(bw)
        audio_suffix = model_name.split('_')[1][:3]
        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
        wav = wav[:, :model.sample_rate * 2]
        wav_in = wav.unsqueeze(0)
        wav_dec = model(wav_in)[0]
        assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()
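For orientation, the bandwidth arithmetic in `_get_model` works out as follows for the 24 kHz configuration: the encoder hop length is 8 * 5 * 4 * 2 = 320, so the frame rate is ceil(24000 / 320) = 75 Hz; each codebook costs log2(1024) = 10 bits per frame, i.e. 750 bit/s, so the maximum bandwidth of 24 kbps gives n_q = 24000 // 750 = 32 codebooks. A minimal usage sketch follows; the import path is assumed from this file's location, it requires the rest of the wavtokenizer package (e.g. the `quantization` module imported above) on the path, and the pretrained checkpoint must be downloadable.

import torch
from encoder.model import EncodecModel  # assumed import path within this package

model = EncodecModel.encodec_model_24khz()  # loads the pretrained weights
model.set_target_bandwidth(6.0)             # must be one of model.target_bandwidths

wav = torch.randn(1, 1, 24000)              # stand-in for one second of mono audio
with torch.no_grad():
    frames = model.encode(wav)              # list of (codes [B, K, T], scale) tuples
    codes, scale = frames[0]
    print(codes.shape)                      # e.g. torch.Size([1, 8, 75]) at 6 kbps
    recon = model.decode(frames)[:, :, :wav.shape[-1]]

At 6 kbps the same arithmetic gives 6000 // 750 = 8 active codebooks, which is where the `[1, 8, 75]` code shape comes from.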
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__init__.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch modules."""
# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformerEncoder
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/__init__.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/conv.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/lstm.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/norm.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/seanet.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/__pycache__/transformer.cpython-310.pyc  0 → 100644
File added
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/conv.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Convolutional layers wrappers and utilities."""
import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

from .norm import ConvLayerNorm


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
        # doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`.
    """
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4             # once you removed padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
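As a concrete check of the padding arithmetic above: for kernel_size=4 and stride=2, padding_total = 4 - 2 = 2; an input of length 5 gives n_frames = (5 - 4 + 2) / 2 + 1 = 2.5 and ideal_length = (3 - 1) * 2 + (4 - 2) = 6, so one extra sample of padding is added to make the last window full. The sketch below exercises both branches of SConv1d; 'constant' padding is used here just so the short input also works, and output lengths follow ceil(T / stride).

import torch
from encoder.modules.conv import SConv1d, get_extra_padding_for_conv1d  # assumed import path

x = torch.randn(2, 8, 5)
print(get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=2))  # 1

causal = SConv1d(8, 16, kernel_size=4, stride=2, causal=True, pad_mode='constant')
same = SConv1d(8, 16, kernel_size=4, stride=2, causal=False, pad_mode='constant')
print(causal(x).shape, same(x).shape)  # both torch.Size([2, 16, 3]), i.e. ceil(5 / 2) frames

The causal branch puts all fixed padding on the left, so no output frame depends on future samples; the non-causal branch splits it symmetrically.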
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/lstm.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""LSTM layers module."""
from torch import nn


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    # def forward(self, x):
    #     x = x.permute(2, 0, 1)
    #     y, _ = self.lstm(x)
    #     if self.skip:
    #         y = y + x
    #     y = y.permute(1, 2, 0)
    #     return y

    # Changed the transpose order
    def forward(self, x):
        # # Insert reshape
        # x = x.reshape(x.shape)
        x1 = x.permute(2, 0, 1)
        y, _ = self.lstm(x1)
        y = y.permute(1, 2, 0)
        if self.skip:
            y = y + x
        return y
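The only subtlety here is layout: SLSTM takes the convolutional [B, C, T] layout, permutes to the [T, B, C] layout that nn.LSTM expects by default, and permutes back before adding the skip connection. A quick shape check, with the import path assumed from this file's location:

import torch
from encoder.modules.lstm import SLSTM  # assumed import path

layer = SLSTM(dimension=64, num_layers=2, skip=True)
x = torch.randn(3, 64, 100)   # convolutional layout [B, C, T]
y = layer(x)
print(y.shape)                # torch.Size([3, 64, 100]); layout preserved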
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/norm.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Normalization modules."""
import typing as tp

import einops
import torch
from torch import nn


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x
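A quick round-trip check of the rearrange logic: a [B, C, T] tensor is reshuffled to [B, T, C], normalized over the channel dimension, and shuffled back. A minimal sketch, with the import path assumed from this file's location:

import torch
from encoder.modules.norm import ConvLayerNorm  # assumed import path

norm = ConvLayerNorm(normalized_shape=16)
x = torch.randn(4, 16, 50)    # [B, C, T]
y = norm(x)
print(y.shape)                # torch.Size([4, 16, 50])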
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/seanet.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Encodec SEANet-based encoder and decoder implementation."""
import typing as tp

import numpy as np
import torch.nn as nn

from . import SConv1d, SConvTranspose1d, SLSTM


class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                        norm=norm, norm_kwargs=norm_params,
                        causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                    causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here,
            which must match the decoder order.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
                 activation_params: dict = {'alpha': 1.0},
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {},
                 kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3,
                 dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect',
                 true_skip: bool = False, compress: int = 2, lstm: int = 2):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode,
                                      compress=compress, true_skip=true_skip)]
            # Add downsampling layers
            model += [
                act(**activation_params),
                SConv1d(mult * n_filters, mult * n_filters * 2,
                        kernel_size=ratio * 2, stride=ratio,
                        norm=norm, norm_kwargs=norm_params,
                        causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [SLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the start of the decoder.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
                 activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None,
                 final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {},
                 kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3,
                 dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect',
                 true_skip: bool = False, compress: int = 2, lstm: int = 2, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [SLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Add upsampling layers
            model += [
                act(**activation_params),
                SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y


def test():
    import torch
    encoder = SEANetEncoder()
    decoder = SEANetDecoder()
    x = torch.randn(1, 1, 24000)
    z = encoder(x)
    assert list(z.shape) == [1, 128, 75], z.shape
    y = decoder(z)
    assert y.shape == x.shape, (x.shape, y.shape)


if __name__ == '__main__':
    test()
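The `test()` above pins down the default geometry: with ratios [8, 5, 4, 2] the hop length is 8 * 5 * 4 * 2 = 320, so one second at 24 kHz (24000 samples) becomes 24000 / 320 = 75 latent frames of dimension 128, which is exactly the shape asserted. A short sketch of how a different ratio set changes the frame rate, using the classes defined in this file:

import torch

# Dropping the final ratio of 2 halves the hop length: 8 * 5 * 4 = 160.
enc = SEANetEncoder(ratios=[8, 5, 4], dimension=128)
dec = SEANetDecoder(ratios=[8, 5, 4], dimension=128)
x = torch.randn(1, 1, 24000)
z = enc(x)
print(z.shape)          # torch.Size([1, 128, 150]), i.e. 24000 / 160 latent frames
print(dec(z).shape)     # torch.Size([1, 1, 24000])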
examples/music_generation/inspiremusic/wavtokenizer/encoder/modules/transformer.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A streamable transformer."""
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F


def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
    """Create time embedding for the given positions, target dimension `dim`.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([
        torch.cos(phase),
        torch.sin(phase),
    ], dim=-1)


class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
        if self.norm_first:
            sa_input = self.norm1(x)
            x = x + self._sa_block(sa_input, x_past, past_context)
            x = x + self._ff_block(self.norm2(x))
        else:
            sa_input = x
            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
            x = self.norm2(x + self._ff_block(x))

        return x, sa_input

    # self-attention block
    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
        _, T, _ = x.shape
        _, H, _ = x_past.shape

        queries = x
        keys = torch.cat([x_past, x], dim=1)
        values = keys

        queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
        keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
        delta = queries_pos - keys_pos
        valid_access = (delta >= 0) & (delta <= past_context)
        x = self.self_attn(queries, keys, values,
                           attn_mask=~valid_access,
                           need_weights=False)[0]
        return self.dropout1(x)


class StreamingTransformerEncoder(nn.Module):
    """TransformerEncoder with streaming support.

    Args:
        dim (int): dimension of the data.
        hidden_scale (int): intermediate dimension of FF module is this times the dimension.
        num_heads (int): number of heads.
        num_layers (int): number of layers.
        max_period (float): maximum period of cosines in the positional embedding.
        past_context (int or None): receptive field for the causal mask, infinite if None.
        gelu (bool): if true uses GeLUs, otherwise use ReLUs.
        norm_in (bool): normalize the input.
        dropout (float): dropout probability.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
                 max_period: float = 10000, past_context: int = 1000, gelu: bool = True,
                 norm_in: bool = True, dropout: float = 0., **kwargs):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)

        self.max_period = max_period
        self.past_context = past_context
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
        else:
            self.norm_in = nn.Identity()
        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                StreamingTransformerEncoderLayer(
                    dim, num_heads, hidden_dim,
                    activation=activation, batch_first=True, dropout=dropout, **kwargs))

    def forward(self, x: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None,
                offset: tp.Union[int, torch.Tensor] = 0):
        B, T, C = x.shape
        if states is None:
            states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]

        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)

        new_state: tp.List[torch.Tensor] = []
        x = self.norm_in(x)
        x = x + pos_emb

        for layer_state, layer in zip(states, self.layers):
            x, new_layer_state = layer(x, layer_state, self.past_context)
            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
            new_state.append(new_layer_state[:, -self.past_context:, :])
        return x, new_state, offset + T
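The streaming contract is that feeding a sequence in chunks while threading `(states, offset)` should reproduce a single full pass, as long as the chunks stay within `past_context` and dropout is zero. A minimal sketch checking this, using the class defined in this file:

import torch

enc = StreamingTransformerEncoder(dim=64, num_heads=4, num_layers=2, past_context=100)
enc.eval()

x = torch.randn(1, 32, 64)  # [B, T, C]
with torch.no_grad():
    full, _, _ = enc(x)
    first, states, offset = enc(x[:, :16])
    second, _, _ = enc(x[:, 16:], states, offset)
    chunked = torch.cat([first, second], dim=1)
print(torch.allclose(full, chunked, atol=1e-5))  # should print True

This works because each layer caches its past self-attention inputs in `states` (trimmed to `past_context`), and `offset` keeps the sinusoidal positions aligned across calls.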
examples/music_generation/inspiremusic/wavtokenizer/encoder/msstftd.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MS-STFT discriminator, provided here for reference."""
import typing as tp

import torchaudio
import torch
from torch import nn
from einops import rearrange

from .modules import NormConv2d


FeatureMapType = tp.List[torch.Tensor]
LogitsType = torch.Tensor
DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]


def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)


class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels. Default: 1
        out_channels (int): Number of output channels. Default: 1
        n_fft (int): Size of FFT for each scale. Default: 1024
        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
        win_length (int): Window size for each scale. Default: 1024
        normalized (bool): Whether to normalize by magnitude after stft. Default: True
        norm (str): Normalization method. Default: `'weight_norm'`
        activation (str): Activation function. Default: `'LeakyReLU'`
        activation_params (dict): Parameters to provide to the activation function.
        growth (int): Growth factor for the filters. Default: 1
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
                 max_filters: int = 1024, filters_scale: int = 1,
                 kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length,
            window_fn=torch.hann_window, normalized=self.normalized,
            center=False, pad_mode=None, power=None)
        spec_channels = 2 * self.in_channels
        self.convs = nn.ModuleList()
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size,
                       padding=get_2d_padding(kernel_size))
        )
        in_chs = min(filters_scale * self.filters, max_filters)
        for i, dilation in enumerate(dilations):
            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
                                         dilation=(dilation, 1),
                                         padding=get_2d_padding(kernel_size, (dilation, 1)),
                                         norm=norm))
            in_chs = out_chs
        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
        self.convs.append(NormConv2d(in_chs, out_chs,
                                     kernel_size=(kernel_size[0], kernel_size[0]),
                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                     norm=norm))
        self.conv_post = NormConv2d(out_chs, self.out_channels,
                                    kernel_size=(kernel_size[0], kernel_size[0]),
                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                    norm=norm)

    def forward(self, x: torch.Tensor):
        fmap = []
        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
        z = torch.cat([z.real, z.imag], dim=1)
        z = rearrange(z, 'b c w t -> b c t w')
        for i, layer in enumerate(self.convs):
            z = layer(z)
            z = self.activation(z)
            fmap.append(z)
        z = self.conv_post(z)
        return z, fmap


class MultiScaleSTFTDiscriminator(nn.Module):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels. Default: 1
        out_channels (int): Number of output channels. Default: 1
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: additional args for STFTDiscriminator.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_ffts: tp.List[int] = [1024, 2048, 512],
                 hop_lengths: tp.List[int] = [256, 512, 128],
                 win_lengths: tp.List[int] = [1024, 2048, 512],
                 **kwargs):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.discriminators = nn.ModuleList([
            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i],
                              **kwargs)
            for i in range(len(n_ffts))
        ])
        self.num_discriminators = len(self.discriminators)

    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps


def test():
    disc = MultiScaleSTFTDiscriminator(filters=32)
    y = torch.randn(1, 1, 24000)
    y_hat = torch.randn(1, 1, 24000)

    y_disc_r, fmap_r = disc(y)
    y_disc_gen, fmap_gen = disc(y_hat)
    assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators
    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])


if __name__ == '__main__':
    test()
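The discriminator returns per-scale logits and feature maps; how they are turned into losses lives elsewhere in the training code. The sketch below shows one common pairing for MS-STFT discriminators, a hinge GAN loss plus an L1 feature-matching term; this is an assumed formulation for illustration, not necessarily the loss this repository uses.

import torch
import torch.nn.functional as F

def adversarial_losses(disc, real, fake):
    # Hinge GAN loss with L1 feature matching; a sketch, not this repo's loss.
    # In a real training loop, detach `fake` when computing the discriminator loss.
    logits_real, fmap_real = disc(real)
    logits_fake, fmap_fake = disc(fake)
    d_loss = sum(torch.relu(1 - lr).mean() + torch.relu(1 + lf).mean()
                 for lr, lf in zip(logits_real, logits_fake))
    g_loss = sum((-lf).mean() for lf in logits_fake)
    feat_loss = sum(F.l1_loss(ff, fr.detach())
                    for fm_f, fm_r in zip(fmap_fake, fmap_real)
                    for ff, fr in zip(fm_f, fm_r))
    return d_loss, g_loss, feat_loss

disc = MultiScaleSTFTDiscriminator(filters=32)
real, fake = torch.randn(1, 1, 24000), torch.randn(1, 1, 24000)
print(adversarial_losses(disc, real, fake))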