ModelZoo / VITA-Audio_pytorch / Commits / 39ac40a9

Commit 39ac40a9, authored Jun 06, 2025 by chenzk

    v1.0

Pipeline #2747 failed with stages in 0 seconds
Changes: 427 · Pipelines: 1
Showing 7 changed files with 2191 additions and 0 deletions (+2191 -0)
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/utterance_cmvn.py  +40 -0
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/hubert_dataset.py  +459 -0
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/raw_audio_dataset.py  +405 -0
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/speech_to_text_dataset.py  +511 -0
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/utterance_mixing_dataset.py  +574 -0
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/base_wrapper_dataset.py  +78 -0
  third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/concat_dataset.py  +124 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/utterance_cmvn.py (new file, 0 → 100644) @ 39ac40a9
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
    """Utterance-level CMVN (cepstral mean and variance normalization)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return UtteranceCMVN(
            _config.get("norm_means", True),
            _config.get("norm_vars", True),
        )

    def __init__(self, norm_means=True, norm_vars=True):
        self.norm_means, self.norm_vars = norm_means, norm_vars

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
        )

    def __call__(self, x):
        mean = x.mean(axis=0)
        square_sums = (x ** 2).sum(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)
        if self.norm_vars:
            var = square_sums / x.shape[0] - mean ** 2
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

        return x
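
As a sanity check on the normalization above, a minimal sketch (separate from the diff) that replicates the __call__ math on a made-up 200x80 feature matrix:

import numpy as np

# Hypothetical check of the CMVN math on random "fbank" features
# (200 frames x 80 mel bins); shapes and values are invented.
x = np.random.randn(200, 80).astype(np.float32) * 3.0 + 5.0

mean = x.mean(axis=0)
var = (x ** 2).sum(axis=0) / x.shape[0] - mean ** 2
y = (x - mean) / np.sqrt(np.maximum(var, 1e-10))

# After utterance-level CMVN, each feature dim is ~zero-mean, unit-variance.
assert np.allclose(y.mean(axis=0), 0.0, atol=1e-4)
assert np.allclose(y.std(axis=0), 1.0, atol=1e-2)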
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/hubert_dataset.py (new file, 0 → 100644) @ 39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import sys
import io
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)

logger = logging.getLogger(__name__)


def load_label(label_path, inds, tot):
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        assert (
            len(labels) == tot
        ), f"number of labels does not match ({len(labels)} != {tot})"
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot):
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=2,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind + 1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )
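
A small sketch (separate from the diff; file name and label values are hypothetical) of the byte-offset bookkeeping that load_label_offset sets up and get_label later consumes:

import itertools

# Each line's UTF-8 byte length is accumulated so a single label can later
# be fetched with seek()/read() instead of holding all labels in memory.
lines = ["71 71 38\n", "5 5 5 102\n", "38 38\n"]
with open("labels.km", "w", encoding="utf-8") as f:
    f.writelines(lines)

code_lengths = [len(l.encode("utf-8")) for l in lines]
offsets = list(itertools.accumulate([0] + code_lengths))  # [0, 9, 19, 25]

with open("labels.km") as f:
    s, e = offsets[1], offsets[2]  # byte span of the second label line
    f.seek(s)
    print(f.read(e - s))           # -> "5 5 5 102\n"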
class HubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        multitask: bool = False,
    ):
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.multitask = multitask
        self.epoch = 0

        self.chunk_names = []
        self.chunk_indices = []
        n_long, n_short = 0, 0
        names, inds, sizes = [], [], []
        with open(manifest_path) as f:
            root = f.readline().strip()
            for ind, line in enumerate(f):
                items = line.strip().split("\t")
                sz = int(items[1])
                if min_keep_sample_size is not None and sz < min_keep_sample_size:
                    n_short += 1
                elif max_keep_sample_size is not None and sz > max_keep_sample_size:
                    n_long += 1
                else:
                    fname = items[0].split(":")
                    if len(fname) > 1:
                        if len(self.chunk_names) == 0 or fname[0] != self.chunk_names[-1]:
                            self.chunk_names.append(fname[0])
                            self.chunk_indices.append(len(names))
                    names.append(items[0])
                    inds.append(ind)
                    sizes.append(sz)
        tot = ind + 1
        logger.info(
            (
                f"max_keep={max_keep_sample_size}, min_keep={min_keep_sample_size}, "
                f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
                f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
            )
        )

        self.audio_root = root
        self.audio_names = names
        self.sizes = sizes

        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, int)
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert (
            label_processors is None or len(label_processors) == self.num_labels
        )
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
        )

    def set_epoch(self, epoch):
        self.epoch = epoch

    def batch_by_size(
        self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1
    ):
        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.required_batch_size_multiple = required_batch_size_multiple
        if isinstance(indices[0], list):
            batch_list = []
            for indice in indices:
                batch = super(HubertDataset, self).batch_by_size(
                    indice, max_tokens, max_sentences, required_batch_size_multiple
                )
                batch_list.append(batch)
            return batch_list
        else:
            return super(HubertDataset, self).batch_by_size(
                indices, max_tokens, max_sentences, required_batch_size_multiple
            )

    def shuffle_batches(self, batches, seed):
        if isinstance(batches[0], list):
            new_batches = []
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
                for batch in batches:
                    np.random.shuffle(batch)
                    new_batches.extend(batch)
            return new_batches
        else:
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

    def reset_batch_sampler(self):
        indices = self.ordered_indices()
        batch_sampler = self.batch_by_size(
            indices,
            self.max_tokens,
            self.max_sentences,
            self.required_batch_size_multiple,
        )
        return batch_sampler

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        _path, slice_ptr = parse_path(wav_path)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            wav_path = io.BytesIO(byte_data)
        wav, cur_sample_rate = sf.read(wav_path)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        net_input = {"source": collated_audios, "padding_mask": padding_mask}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list

        if self.multitask:
            batch["task"] = "multitask"
        else:
            batch["task"] = "hubert"
        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat(
                    [audio, audio.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            if len(self.chunk_names) > 0:
                with data_utils.numpy_seed(self.epoch):
                    self.chunk_order = np.random.permutation(len(self.chunk_names))
                chunk_count = 0
                tmp_sizes = []
                tmp_indices = []
                indice = []
                for i in self.chunk_order:
                    chunk_count += 1
                    start = self.chunk_indices[i]
                    end = (
                        self.chunk_indices[i + 1]
                        if i < len(self.chunk_names) - 1
                        else len(self)
                    )
                    size = list(self.sizes[start:end])
                    tmp_indices.extend(list(np.arange(start, end)))
                    tmp_sizes.extend(size)
                    if chunk_count % 10 == 0 or i == self.chunk_order[0]:
                        order = [np.random.permutation(len(tmp_indices))]
                        order.append(
                            np.minimum(
                                np.array(tmp_sizes),
                                self.max_sample_size,
                            )
                        )
                        sort_idx = np.lexsort(order)[::-1]
                        indice.append([tmp_indices[k] for k in sort_idx])
                        tmp_indices = []
                        tmp_sizes = []
                return indice
            else:
                order = [np.random.permutation(len(self))]
                order.append(
                    np.minimum(
                        np.array(self.sizes),
                        self.max_sample_size,
                    )
                )
                return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
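
A worked example (separate from the diff) of the sample-to-frame arithmetic in collater_frm_label, assuming 16 kHz audio and 50 Hz labels, which is a typical setup for HuBERT-style discrete units:

# Sketch of the sample-to-frame mapping: one label frame per 320 samples
# under the assumed rates below.
sample_rate = 16000
label_rate = 50
s2f = label_rate / sample_rate              # 0.003125 frames per sample

audio_start = 4800                          # crop start from crop_to_max_size
audio_size = 40000                          # cropped window length in samples

frm_start = int(round(audio_start * s2f))   # 15
frm_size = int(round(audio_size * s2f))     # 125

# The aligned label slice is labels[15:140], mirroring
# t[s : s + frm_size] in the collater.
print(frm_start, frm_size)                  # -> 15 125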
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/raw_audio_dataset.py (new file, 0 → 100644) @ 39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import io

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import data_utils
from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)

logger = logging.getLogger(__name__)


class RawAudioDataset(FairseqDataset):
    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        self.epoch = 0
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def postprocess(self, feats, curr_sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end], start

    def _compute_mask_indices(self, dims, padding_mask):
        B, T, C = dims
        mask_indices, mask_channel_indices = None, None
        if self.mask_compute_kwargs["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_compute_kwargs["mask_prob"],
                self.mask_compute_kwargs["mask_length"],
                self.mask_compute_kwargs["mask_selection"],
                self.mask_compute_kwargs["mask_other"],
                min_masks=2,
                no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
                min_space=self.mask_compute_kwargs["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if self.mask_compute_kwargs["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_compute_kwargs["mask_channel_prob"],
                self.mask_compute_kwargs["mask_channel_length"],
                self.mask_compute_kwargs["mask_channel_selection"],
                self.mask_compute_kwargs["mask_channel_other"],
                no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
                min_space=self.mask_compute_kwargs["mask_channel_min_space"],
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices

    @staticmethod
    def _bucket_tensor(tensor, num_pad, value):
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(s) for s in sources]

        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                assert self.pad
                collated_sources[i] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i], start = self.crop_to_max_size(source, target_size)

        input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)

        if self.compute_mask_indices:
            B = input["source"].size(0)
            T = self._get_mask_indices_dims(input["source"].size(-1))
            padding_mask_reshaped = input["padding_mask"].clone()
            extra = padding_mask_reshaped.size(1) % T
            if extra > 0:
                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
            padding_mask_reshaped = padding_mask_reshaped.view(
                padding_mask_reshaped.size(0), T, -1
            )
            padding_mask_reshaped = padding_mask_reshaped.all(-1)
            input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                padding_mask_reshaped,
            )
            input["mask_indices"] = mask_indices
            input["mask_channel_indices"] = mask_channel_indices
            out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = input
        return out

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        if size not in self._features_size_map:
            L_in = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if self.pad:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            if len(self.chunk_names) > 0:
                with data_utils.numpy_seed(self.epoch):
                    self.chunk_order = np.random.permutation(len(self.chunk_names))
                chunk_count = 0
                tmp_sizes = []
                tmp_indices = []
                indice = []
                for i in self.chunk_order:
                    chunk_count += 1
                    start = self.chunk_indices[i]
                    end = (
                        self.chunk_indices[i + 1]
                        if i < len(self.chunk_names) - 1
                        else len(self)
                    )
                    size = list(self.sizes[start:end])
                    tmp_indices.extend(list(np.arange(start, end)))
                    tmp_sizes.extend(size)
                    if chunk_count % 10 == 0 or i == self.chunk_order[0]:
                        order = [np.random.permutation(len(tmp_indices))]
                        order.append(
                            np.minimum(
                                np.array(tmp_sizes),
                                self.max_sample_size,
                            )
                        )
                        sort_idx = np.lexsort(order)[::-1]
                        indice.append([tmp_indices[k] for k in sort_idx])
                        tmp_indices = []
                        tmp_sizes = []
                return indice
            else:
                order = [np.random.permutation(len(self))]
                order.append(
                    np.minimum(
                        np.array(self.sizes),
                        self.max_sample_size,
                    )
                )
                return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def batch_by_size(
        self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1
    ):
        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.required_batch_size_multiple = required_batch_size_multiple
        if isinstance(indices[0], list):
            batch_list = []
            for indice in indices:
                batch = super(RawAudioDataset, self).batch_by_size(
                    indice, max_tokens, max_sentences, required_batch_size_multiple
                )
                batch_list.append(batch)
            return batch_list
        else:
            return super(RawAudioDataset, self).batch_by_size(
                indices, max_tokens, max_sentences, required_batch_size_multiple
            )

    def shuffle_batches(self, batches, seed):
        if isinstance(batches[0], list):
            new_batches = []
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
                for batch in batches:
                    np.random.shuffle(batch)
                    new_batches.extend(batch)
            return new_batches
        else:
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

    def reset_batch_sampler(self):
        indices = self.ordered_indices()
        batch_sampler = self.batch_by_size(
            indices,
            self.max_tokens,
            self.max_sentences,
            self.required_batch_size_multiple,
        )
        return batch_sampler

    def set_bucket_info(self, num_buckets):
        self.num_buckets = num_buckets
        if self.num_buckets > 0:
            self._collated_sizes = np.minimum(
                np.array(self.sizes),
                self.max_sample_size,
            )
            self.buckets = get_buckets(
                self._collated_sizes,
                self.num_buckets,
            )
            self._bucketed_sizes = get_bucketed_sizes(
                self._collated_sizes, self.buckets
            )
            logger.info(
                f"{len(self.buckets)} bucket(s) for the audio dataset: "
                f"{self.buckets}"
            )
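
The conv output-length recurrence in _get_mask_indices_dims, extracted into a standalone sketch (separate from the diff); the layer spec below is the common wav2vec 2.0 feature-extractor default, assumed here for illustration:

# Each tuple is (dim, kernel_size, stride).
conv_feature_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

def output_length(L_in, layers, padding=0, dilation=1):
    # Same per-layer formula as _get_mask_indices_dims.
    for _, kernel_size, stride in layers:
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out

# 1 second of 16 kHz audio -> 49 encoder frames (~20 ms per frame), so the
# mask indices computed over T frames align with ~20 ms steps.
print(output_length(16000, conv_feature_layers))  # -> 49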
class FileAudioDataset(RawAudioDataset):
    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.chunk_names = []
        self.chunk_indices = []

        self.fnames = []
        self.skipped = []
        skipped = 0
        count = 0
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                # assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped.append(i)
                    self.skipped_indices.add(i)
                    continue
                if pad and max_sample_size is not None and sz > max_sample_size:
                    skipped += 1
                    self.skipped.append(i)
                    continue
                fname = items[0].split(":")
                if len(fname) > 1:
                    if len(self.chunk_names) == 0 or fname[0] != self.chunk_names[-1]:
                        self.chunk_names.append(fname[0])
                        self.chunk_indices.append(len(self.fnames))
                self.fnames.append(items[0])
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )
            pass

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        import soundfile as sf

        path_or_fp = os.path.join(self.root_dir, str(self.fnames[index]))
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, curr_sample_rate)
        return {"id": index, "source": wav}
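
For reference, a sketch (separate from the diff; paths and sizes are invented) of the manifest TSV layout that FileAudioDataset, and HubertDataset above, parse — the first line is the audio root, every other line is "<relative path>\t<num samples>", and a colon-delimited first column triggers the stored-zip/chunk handling:

manifest = """\
/data/audio
spk1/utt001.wav\t52480
spk1/utt002.wav\t38400
shard0.zip:2048:51200\t25600
"""
with open("train.tsv", "w") as f:
    f.write(manifest)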
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/speech_to_text_dataset.py (new file, 0 → 100644) @ 39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import io
import logging
import os.path as op
import re
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from fairseq.data import (
    ConcatDataset,
    Dictionary,
    FairseqDataset,
    ResamplingDataset,
    data_utils as fairseq_data_utils,
)
from fairseq.data.audio.audio_utils import (
    get_fbank,
    get_waveform,
    read_from_stored_zip,
    is_npy_data,
    is_sf_audio_data,
    parse_path,
    FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform

logger = logging.getLogger(__name__)


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path):
        try:
            import yaml
        except ImportError:
            print("Please install PyYAML to load YAML files for S2T data config")
        self.config = {}
        if op.isfile(yaml_path):
            try:
                with open(yaml_path) as f:
                    self.config = yaml.load(f, Loader=yaml.FullLoader)
            except Exception as e:
                raise Exception(f"Failed to load config from {yaml_path}: {e}")
        else:
            raise FileNotFoundError(f"{yaml_path} not found")

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("pre_tokenizer", {"tokenizer": None})

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("bpe_tokenizer", {"bpe": None})

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_feature_transforms(self, split, is_train):
        """Split-specific feature transforms. Allowing train set wildcard `_train`,
        evaluation set wildcard `_eval` and general wildcard `*` for matching."""
        from copy import deepcopy

        cfg = deepcopy(self.config)
        _cur = cfg.get("transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        cfg["transforms"] = cur
        return cfg
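
A hypothetical config (separate from the diff), written as the dict yaml.load would produce from the YAML file, exercising the properties above; all key values here are assumptions:

config = {
    "vocab_filename": "spm_unigram10000.txt",   # assumed vocab file name
    "shuffle": True,
    "prepend_tgt_lang_tag": True,
    "input_feat_per_channel": 80,
    "sampling_alpha": 0.5,
    "audio_root": "/data/s2t",
    "transforms": {
        "_train": ["utterance_cmvn", "specaugment"],
        "*": ["utterance_cmvn"],
    },
}

# get_feature_transforms("dev", is_train=False) would fall through
# split -> "_eval" -> "*" and select ["utterance_cmvn"] here.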
def get_features_from_npy_or_audio(path):
    ext = op.splitext(op.basename(path))[1]
    if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        raise ValueError(f'Unsupported file format for "{path}"')
    return np.load(path) if ext == ".npy" else get_fbank(path)


def get_features_or_waveform_from_stored_zip(
    path, byte_offset, byte_size, need_waveform=False
):
    assert path.endswith(".zip")
    data = read_from_stored_zip(path, byte_offset, byte_size)
    f = io.BytesIO(data)
    if is_npy_data(data):
        features_or_waveform = np.load(f)
    elif is_sf_audio_data(data):
        features_or_waveform = \
            get_waveform(f, always_2d=False)[0] if need_waveform else get_fbank(f)
    else:
        raise ValueError(f'Unknown file format for "{path}"')
    return features_or_waveform


def get_features_or_waveform(path: str, need_waveform=False):
    """Get speech features from .npy file or waveform from .wav/.flac file.
    The file may be inside an uncompressed ZIP file and is accessed via byte
    offset and length.

    Args:
        path (str): File path in the format of "<.npy/.wav/.flac path>" or
            "<zip path>:<byte offset>:<byte length>".
        need_waveform (bool): return waveform instead of features.

    Returns:
        features_or_waveform (numpy.ndarray): speech features or waveform.
    """
    _path, slice_ptr = parse_path(path)
    if len(slice_ptr) == 0:
        if need_waveform:
            return get_waveform(_path, always_2d=False)
        return get_features_from_npy_or_audio(_path)
    elif len(slice_ptr) == 2:
        features_or_waveform = get_features_or_waveform_from_stored_zip(
            _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform
        )
    else:
        raise ValueError(f"Invalid path: {path}")

    return features_or_waveform


def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
            length of i-th frame and f_dim is static dimension of features
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out
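
_collate_frames in miniature (separate from the diff), re-implemented inline so the sketch runs without fairseq installed; shapes are made up:

import torch

# Two variable-length "fbank" tensors get zero-padded to the longest item.
a = torch.ones(3, 4)       # 3 frames, feature dim 4
b = torch.ones(5, 4) * 2   # 5 frames

max_len = max(t.size(0) for t in (a, b))
out = a.new_zeros((2, max_len, a.size(1)))
for i, v in enumerate((a, b)):
    out[i, : v.size(0)] = v

print(out.shape)           # torch.Size([2, 5, 4]); out[0, 3:] stays zero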
class SpeechToTextDataset(FairseqDataset):
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
    ):
        self.split, self.is_train_split = split, is_train_split
        self.data_cfg = data_cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or (
            tgt_dict is not None and tgt_texts is not None
        )
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.ids = ids
        self.shuffle = data_cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms(split, is_train_split)
        )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer

        logger.info(self.__repr__())

    def __repr__(self):
        return (
            self.__class__.__name__
            + f'(split="{self.split}", n_samples={self.n_samples}, '
            f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
            f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
        )

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        if self.data_cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [
                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
            ]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    def tokenize_text(self, text: str):
        if self.pre_tokenizer is not None:
            text = self.pre_tokenizer.encode(text)
        if self.bpe_tokenizer is not None:
            text = self.bpe_tokenizer.encode(text)
        return text

    def __getitem__(
        self, index: int
    ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
        source = get_features_or_waveform(
            self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input
        )
        if self.feature_transforms is not None:
            assert not self.data_cfg.use_audio_input
            source = self.feature_transforms(source)
        source = torch.from_numpy(source).float()

        target = None
        if self.tgt_texts is not None:
            tokenized = self.tokenize_text(self.tgt_texts[index])
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=True
            ).long()
            if self.data_cfg.prepend_tgt_lang_tag:
                lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
                lang_tag_idx = self.tgt_dict.index(lang_tag)
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
        return index, source, target

    def __len__(self):
        return self.n_samples

    def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
        frames = _collate_frames(
            [s for _, s, _ in samples], self.data_cfg.use_audio_input
        )
        # sort samples by descending number of frames
        n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [t for _, _, t in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [t.size(0) for _, _, t in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [t for _, _, t in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(t.size(0) for _, _, t in samples)

        out = {
            "id": indices,
            "net_input": {
                "src_tokens": frames,
                "src_lengths": n_frames,
                "prev_output_tokens": prev_output_tokens,
            },
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        return out

    def num_tokens(self, index):
        return self.n_frames[index]

    def size(self, index):
        t_len = 0
        if self.tgt_texts is not None:
            tokenized = self.tokenize_text(self.tgt_texts[index])
            t_len = len(tokenized.split(" "))
        return self.n_frames[index], t_len

    @property
    def sizes(self):
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)

    def prefetch(self, indices):
        raise False
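
The lang-tag convention used by is_lang_tag and check_tgt_lang_tag, checked in a standalone sketch (separate from the diff; the language codes are assumed):

import re

LANG_TAG_TEMPLATE = "<lang:{}>"
# The template becomes a regex by swapping "{}" for "(.*)".
pattern = LANG_TAG_TEMPLATE.replace("{}", "(.*)")

print(bool(re.match(pattern, "<lang:de>")))   # True
print(bool(re.match(pattern, "hello")))       # False
print(LANG_TAG_TEMPLATE.format("de"))         # "<lang:de>", prepended as BOS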
class SpeechToTextDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
    KEY_TGT_TEXT = "tgt_text"
    # optional columns
    KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[List[Dict]],
        data_cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
    ) -> SpeechToTextDataset:
        audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
        speakers, src_langs, tgt_langs = [], [], []
        for s in samples:
            ids.extend([ss[cls.KEY_ID] for ss in s])
            audio_paths.extend(
                [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
            )
            n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s])
            tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s])
            src_texts.extend(
                [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
            )
            speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s])
            src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s])
            tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s])
        return SpeechToTextDataset(
            split_name,
            is_train_split,
            data_cfg,
            audio_paths,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
        )

    @classmethod
    def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0):
        """Size ratios for temperature-based sampling
        (https://arxiv.org/abs/1907.05019)"""
        _sizes = np.array(sizes)
        prob = _sizes / _sizes.sum()
        smoothed_prob = prob ** alpha
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        size_ratio = (smoothed_prob * _sizes.sum()) / _sizes

        o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
        logger.info(f"original sampling probability: {o_str}")
        p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
        logger.info(f"balanced sampling probability: {p_str}")
        sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
        logger.info(f"balanced sampling size ratio: {sr_str}")
        return size_ratio.tolist()

    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
    ) -> SpeechToTextDataset:
        samples = []
        _splits = splits.split(",")
        for split in _splits:
            tsv_path = op.join(root, f"{split}.tsv")
            if not op.isfile(tsv_path):
                raise FileNotFoundError(f"Dataset not found: {tsv_path}")
            with open(tsv_path) as f:
                reader = csv.DictReader(
                    f,
                    delimiter="\t",
                    quotechar=None,
                    doublequote=False,
                    lineterminator="\n",
                    quoting=csv.QUOTE_NONE,
                )
                samples.append([dict(e) for e in reader])
        assert len(samples) > 0

        datasets = [
            cls._from_list(
                name,
                is_train_split,
                [s],
                data_cfg,
                tgt_dict,
                pre_tokenizer,
                bpe_tokenizer,
            )
            for name, s in zip(_splits, samples)
        ]

        if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls._get_size_ratios(
                _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
            )
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for d, r in zip(datasets, size_ratios)
            ]
        return ConcatDataset(datasets)
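
The _get_size_ratios math worked through on two assumed split sizes (separate from the diff), showing how sampling_alpha < 1 upsamples the smaller split:

import numpy as np

sizes = np.array([100000, 10000])      # two hypothetical train splits
alpha = 0.5

prob = sizes / sizes.sum()             # [0.909, 0.091]
smoothed = prob ** alpha
smoothed = smoothed / smoothed.sum()   # [0.760, 0.240]
size_ratio = smoothed * sizes.sum() / sizes

# The small split is resampled up ~2.6x, the large one down to ~0.84x.
print(size_ratio.round(3))             # -> [0.836 2.643]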
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/utterance_mixing_dataset.py (new file, 0 → 100644) @ 39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import sys
import io
import json
import h5py
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)

logger = logging.getLogger(__name__)


def load_label(label_path, inds, tot):
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        assert (
            len(labels) == tot
        ), f"number of labels does not match ({len(labels)} != {tot})"
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot):
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind + 1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )
class
UtteranceMixingDataset
(
FairseqDataset
):
def
__init__
(
self
,
manifest_path
:
str
,
sample_rate
:
float
,
label_paths
:
List
[
str
],
label_rates
:
Union
[
List
[
float
],
float
],
# -1 for sequence labels
pad_list
:
List
[
str
],
eos_list
:
List
[
str
],
label_processors
:
Optional
[
List
[
Any
]]
=
None
,
max_keep_sample_size
:
Optional
[
int
]
=
None
,
min_keep_sample_size
:
Optional
[
int
]
=
None
,
max_sample_size
:
Optional
[
int
]
=
None
,
shuffle
:
bool
=
True
,
pad_audio
:
bool
=
False
,
normalize
:
bool
=
False
,
store_labels
:
bool
=
True
,
random_crop
:
bool
=
False
,
single_target
:
bool
=
False
,
multitask
:
bool
=
False
,
mixing_max_len
:
int
=
-
1
,
mixing_prob
:
float
=
0.2
,
mixing_num
:
int
=
1
,
mixing_noise
:
bool
=
False
,
mixing_noise_prob
:
float
=
0.0
,
mixing_noise_num
:
int
=
1
,
noise_path
:
Optional
[
str
]
=
None
,
):
self
.
sample_rate
=
sample_rate
self
.
shuffle
=
shuffle
self
.
random_crop
=
random_crop
self
.
num_labels
=
len
(
label_paths
)
self
.
pad_list
=
pad_list
self
.
eos_list
=
eos_list
self
.
label_processors
=
label_processors
self
.
single_target
=
single_target
self
.
multitask
=
multitask
self
.
epoch
=
0
self
.
chunk_names
=
[]
self
.
chunk_indices
=
[]
n_long
,
n_short
=
0
,
0
names
,
inds
,
sizes
=
[],
[],
[]
bnds
=
[]
bnd_path
=
manifest_path
.
replace
(
'tsv'
,
'bnd'
)
if
os
.
path
.
exists
(
bnd_path
):
with
open
(
bnd_path
)
as
f
:
bnds
=
f
.
readlines
()
new_bnds
=
[]
with
open
(
manifest_path
)
as
f
:
root
=
f
.
readline
().
strip
()
for
ind
,
line
in
enumerate
(
f
):
items
=
line
.
strip
().
split
(
"
\t
"
)
sz
=
int
(
items
[
1
])
if
min_keep_sample_size
is
not
None
and
sz
<
min_keep_sample_size
:
n_short
+=
1
elif
max_keep_sample_size
is
not
None
and
sz
>
max_keep_sample_size
:
n_long
+=
1
else
:
fname
=
items
[
0
].
split
(
":"
)
if
len
(
fname
)
>
1
:
if
len
(
self
.
chunk_names
)
==
0
or
fname
[
0
]
!=
self
.
chunk_names
[
-
1
]:
self
.
chunk_names
.
append
(
fname
[
0
])
self
.
chunk_indices
.
append
(
len
(
names
))
names
.
append
(
items
[
0
])
inds
.
append
(
ind
)
sizes
.
append
(
sz
)
if
len
(
bnds
)
>
0
:
new_bnds
.
append
(
list
(
map
(
int
,
bnds
[
ind
].
strip
().
split
())))
tot
=
ind
+
1
logger
.
info
(
(
f
"max_keep=
{
max_keep_sample_size
}
, min_keep=
{
min_keep_sample_size
}
, "
f
"loaded
{
len
(
names
)
}
, skipped
{
n_short
}
short and
{
n_long
}
long, "
f
"longest-loaded=
{
max
(
sizes
)
}
, shortest-loaded=
{
min
(
sizes
)
}
"
)
)
self
.
audio_root
=
root
self
.
audio_names
=
names
self
.
sizes
=
sizes
self
.
bnds
=
new_bnds
self
.
label_rates
=
(
[
label_rates
for
_
in
range
(
len
(
label_paths
))]
if
isinstance
(
label_rates
,
int
)
else
label_rates
)
self
.
store_labels
=
store_labels
if
store_labels
:
self
.
label_list
=
[
load_label
(
p
,
inds
,
tot
)
for
p
in
label_paths
]
else
:
self
.
label_paths
=
label_paths
self
.
label_offsets_list
=
[
load_label_offset
(
p
,
inds
,
tot
)
for
p
in
label_paths
]
assert
(
label_processors
is
None
or
len
(
label_processors
)
==
self
.
num_labels
)
for
label_path
,
label_rate
in
zip
(
label_paths
,
self
.
label_rates
):
verify_label_lengths
(
self
.
sizes
,
sample_rate
,
label_path
,
label_rate
,
inds
,
tot
)
self
.
max_sample_size
=
(
max_sample_size
if
max_sample_size
is
not
None
else
sys
.
maxsize
)
self
.
pad_audio
=
pad_audio
self
.
normalize
=
normalize
logger
.
info
(
f
"pad_audio=
{
pad_audio
}
, random_crop=
{
random_crop
}
, "
f
"normalize=
{
normalize
}
, max_sample_size=
{
self
.
max_sample_size
}
"
)
self
.
mixing_max_len
=
mixing_max_len
self
.
mixing_prob
=
mixing_prob
self
.
mixing_num
=
mixing_num
self
.
mixing_noise
=
mixing_noise
self
.
mixing_noise_prob
=
mixing_noise_prob
self
.
mixing_noise_num
=
mixing_noise_num
self
.
noise_path
=
noise_path
if
self
.
mixing_noise
:
assert
os
.
path
.
exists
(
self
.
noise_path
),
f
"Invalid noise path
{
self
.
noise_path
}
"
self
.
noise_list
=
json
.
load
(
open
(
self
.
noise_path
,
'r'
))
self
.
noise_container
=
{}
else
:
self
.
noise_list
=
[]
logger
.
info
(
f
"mixing_max_len=
{
mixing_max_len
}
, mixing_prob=
{
mixing_prob
}
, mixing_num=
{
mixing_num
}
,"
f
"mixing_noise=
{
mixing_noise
}
, mixing_noise_prob=
{
mixing_noise_prob
}
, mixing_noise_num=
{
mixing_noise_num
}
,"
f
"noise_path=
{
noise_path
}
, noise_list_len=
{
len
(
self
.
noise_list
)
}
,"
)
def
set_epoch
(
self
,
epoch
):
self
.
epoch
=
epoch
def
batch_by_size
(
self
,
indices
,
max_tokens
=
None
,
max_sentences
=
None
,
required_batch_size_multiple
=
1
):
self
.
max_tokens
=
max_tokens
self
.
max_sentences
=
max_sentences
self
.
required_batch_size_multiple
=
required_batch_size_multiple
if
isinstance
(
indices
[
0
],
list
):
batch_list
=
[]
for
indice
in
indices
:
batch
=
super
(
UtteranceMixingDataset
,
self
).
batch_by_size
(
indice
,
max_tokens
,
max_sentences
,
required_batch_size_multiple
)
batch_list
.
append
(
batch
)
return
batch_list
else
:
return
super
(
UtteranceMixingDataset
,
self
).
batch_by_size
(
indices
,
max_tokens
,
max_sentences
,
required_batch_size_multiple
)
def
shuffle_batches
(
self
,
batches
,
seed
):
if
isinstance
(
batches
[
0
],
list
):
new_batches
=
[]
with
data_utils
.
numpy_seed
(
seed
):
np
.
random
.
shuffle
(
batches
)
for
batch
in
batches
:
np
.
random
.
shuffle
(
batch
)
new_batches
.
extend
(
batch
)
return
new_batches
else
:
with
data_utils
.
numpy_seed
(
seed
):
np
.
random
.
shuffle
(
batches
)
return
batches
def
reset_batch_sampler
(
self
):
indices
=
self
.
ordered_indices
()
batch_sampler
=
self
.
batch_by_size
(
indices
,
self
.
max_tokens
,
self
.
max_sentences
,
self
.
required_batch_size_multiple
)
return
batch_sampler
def
get_audio
(
self
,
index
):
import
soundfile
as
sf
wav_path
=
os
.
path
.
join
(
self
.
audio_root
,
self
.
audio_names
[
index
])
_path
,
slice_ptr
=
parse_path
(
wav_path
)
if
len
(
slice_ptr
)
==
2
:
byte_data
=
read_from_stored_zip
(
_path
,
slice_ptr
[
0
],
slice_ptr
[
1
])
assert
is_sf_audio_data
(
byte_data
)
wav_path
=
io
.
BytesIO
(
byte_data
)
wav
,
cur_sample_rate
=
sf
.
read
(
wav_path
)
wav
=
torch
.
from_numpy
(
wav
).
float
()
wav
=
self
.
postprocess
(
wav
,
cur_sample_rate
)
return
wav
def
get_label
(
self
,
index
,
label_idx
):
if
self
.
store_labels
:
label
=
self
.
label_list
[
label_idx
][
index
]
else
:
with
open
(
self
.
label_paths
[
label_idx
])
as
f
:
offset_s
,
offset_e
=
self
.
label_offsets_list
[
label_idx
][
index
]
f
.
seek
(
offset_s
)
label
=
f
.
read
(
offset_e
-
offset_s
)
if
self
.
label_processors
is
not
None
:
label
=
self
.
label_processors
[
label_idx
](
label
)
return
label
def
get_labels
(
self
,
index
):
return
[
self
.
get_label
(
index
,
i
)
for
i
in
range
(
self
.
num_labels
)]
def
__getitem__
(
self
,
index
):
wav
=
self
.
get_audio
(
index
)
labels
=
self
.
get_labels
(
index
)
if
len
(
self
.
bnds
)
>
0
:
bnd
=
self
.
bnds
[
index
]
else
:
bnd
=
[]
return
{
"id"
:
index
,
"source"
:
wav
,
"label_list"
:
labels
,
"boundary"
:
bnd
}
def
__len__
(
self
):
return
len
(
self
.
sizes
)
def
crop_to_max_size
(
self
,
wav
,
target_size
):
size
=
len
(
wav
)
diff
=
size
-
target_size
if
diff
<=
0
:
return
wav
,
0
start
,
end
=
0
,
target_size
if
self
.
random_crop
:
start
=
np
.
random
.
randint
(
0
,
diff
+
1
)
end
=
size
-
diff
+
start
return
wav
[
start
:
end
],
start
def
collater
(
self
,
samples
):
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples
=
[
s
for
s
in
samples
if
s
[
"source"
]
is
not
None
]
if
len
(
samples
)
==
0
:
return
{}
audios
=
[
s
[
"source"
]
for
s
in
samples
]
audio_sizes
=
[
len
(
s
)
for
s
in
audios
]
bnds
=
[
s
[
"boundary"
]
for
s
in
samples
]
if
self
.
pad_audio
:
audio_size
=
min
(
max
(
audio_sizes
),
self
.
max_sample_size
)
else
:
audio_size
=
min
(
min
(
audio_sizes
),
self
.
max_sample_size
)
collated_audios
,
padding_mask
,
audio_starts
=
self
.
collater_audio
(
audios
,
audio_size
)
if
self
.
mixing_prob
>
0
:
collated_audios
=
self
.
mixing_collated_audios
(
collated_audios
)
targets_by_label
=
[
[
s
[
"label_list"
][
i
]
for
s
in
samples
]
for
i
in
range
(
self
.
num_labels
)
]
targets_list
,
lengths_list
,
ntokens_list
=
self
.
collater_label
(
targets_by_label
,
audio_size
,
audio_starts
)
net_input
=
{
"source"
:
collated_audios
,
"padding_mask"
:
padding_mask
,
"boundary"
:
bnds
}
batch
=
{
"id"
:
torch
.
LongTensor
([
s
[
"id"
]
for
s
in
samples
]),
"net_input"
:
net_input
,
}
if
self
.
single_target
:
batch
[
"target_lengths"
]
=
lengths_list
[
0
]
batch
[
"ntokens"
]
=
ntokens_list
[
0
]
batch
[
"target"
]
=
targets_list
[
0
]
else
:
batch
[
"target_lengths_list"
]
=
lengths_list
batch
[
"ntokens_list"
]
=
ntokens_list
batch
[
"target_list"
]
=
targets_list
if
self
.
multitask
:
batch
[
"task"
]
=
"multitask"
else
:
batch
[
"task"
]
=
"wavlm"
return
batch
    def mixing_collated_audios(self, source):
        # mixing utterance or noise within the current batch
        B = source.shape[0]
        T = source.shape[1]
        mixing_max_len = T // 2 if self.mixing_max_len < 0 else T // self.mixing_max_len
        mixing_max_len = T if mixing_max_len > T else mixing_max_len

        for i in range(B):
            if np.random.random() < self.mixing_prob:
                if self.mixing_noise and np.random.random() < self.mixing_noise_prob:
                    # mixing with noise
                    choices = np.random.choice(self.noise_list, self.mixing_noise_num)
                    for c in choices:
                        path, key, start, end = c["loc"].split("\t")
                        if path not in self.noise_container:
                            self.noise_container[path] = h5py.File(path, "r")["wav"]
                        noise = self.noise_container[path][int(start):int(end)]
                        noise = noise.astype(np.float32) / np.iinfo(np.int16).max

                        ref_pow = np.mean(source[i].numpy() ** 2)
                        noise_pow = np.mean(noise ** 2)
                        if noise_pow == 0:
                            scale = 0
                        else:
                            snr = np.random.uniform(-5, 20)
                            scale = (ref_pow / (noise_pow * 10 ** (snr / 10))) ** 0.5
                        noise = scale * noise
                        noise = torch.from_numpy(noise).type_as(source)

                        c_len = np.random.randint(0, mixing_max_len + 1)
                        c_len = min(c_len, noise.shape[0])
                        c_end = np.random.randint(c_len, noise.shape[0] + 1)
                        c_start = c_end - c_len
                        s_end = np.random.randint(c_len, T + 1)
                        s_start = s_end - c_len

                        source[i, s_start:s_end] += noise[c_start:c_end]
                else:
                    # mixing with utterance
                    choices = np.random.choice(range(B), self.mixing_num, replace=True)
                    for c in choices:
                        c_len = np.random.randint(0, mixing_max_len + 1)
                        c_end = np.random.randint(c_len, T + 1)
                        c_start = c_end - c_len
                        s_end = np.random.randint(c_len, T + 1)
                        s_start = s_end - c_len

                        ref_pow = np.mean(source[i].numpy() ** 2)
                        noise_pow = np.mean(source[c].numpy() ** 2)
                        if noise_pow == 0:
                            scale = 0
                        else:
                            snr = np.random.uniform(-5, 5)
                            scale = (ref_pow / (noise_pow * 10 ** (snr / 10))) ** 0.5
                        source[i, s_start:s_end] += source[c, c_start:c_end].clone() * scale

                if self.normalize:
                    with torch.no_grad():
                        source[i] = F.layer_norm(source[i], source[i].shape)

        return source
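The mixing scale is chosen so the interfering signal lands at a randomly drawn SNR relative to the reference: solving ref_pow / (scale**2 * noise_pow) == 10**(snr / 10) for scale gives the expression used above. A quick numeric check (values are illustrative):

import numpy as np

ref_pow, noise_pow, snr = 0.04, 0.01, 10.0   # target SNR of 10 dB
scale = (ref_pow / (noise_pow * 10 ** (snr / 10))) ** 0.5
achieved = 10 * np.log10(ref_pow / (scale ** 2 * noise_pow))
assert np.isclose(achieved, snr)             # the scaled mix hits the drawn SNR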
    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat(
                    [audio, audio.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens
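Here s2f converts waveform sample offsets into label-frame offsets. With 16 kHz audio and a 50 Hz label rate (typical for HuBERT-style targets; assumed here purely for illustration), each label frame covers 320 samples:

sample_rate, label_rate = 16000, 50     # assumed rates, for illustration only
s2f = label_rate / sample_rate          # 0.003125 label frames per audio sample
audio_start, audio_size = 8000, 48000
assert int(round(audio_start * s2f)) == 25    # crop offset of 0.5 s -> 25 frames
assert int(round(audio_size * s2f)) == 150    # 3 s of audio -> 150 label frames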
    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            if len(self.chunk_names) > 0:
                with data_utils.numpy_seed(self.epoch):
                    self.chunk_order = np.random.permutation(len(self.chunk_names))
                chunk_count = 0
                tmp_sizes = []
                tmp_indices = []
                indice = []
                for i in self.chunk_order:
                    chunk_count += 1
                    start = self.chunk_indices[i]
                    end = (
                        self.chunk_indices[i + 1]
                        if i < len(self.chunk_names) - 1
                        else len(self)
                    )
                    size = list(self.sizes[start:end])
                    tmp_indices.extend(list(np.arange(start, end)))
                    tmp_sizes.extend(size)
                    if chunk_count % 10 == 0 or i == self.chunk_order[0]:
                        order = [np.random.permutation(len(tmp_indices))]
                        order.append(
                            np.minimum(
                                np.array(tmp_sizes),
                                self.max_sample_size,
                            )
                        )
                        sort_idx = np.lexsort(order)[::-1]
                        indice.append([tmp_indices[k] for k in sort_idx])
                        tmp_indices = []
                        tmp_sizes = []
                return indice
            else:
                order = [np.random.permutation(len(self))]
                order.append(
                    np.minimum(
                        np.array(self.sizes),
                        self.max_sample_size,
                    )
                )
                return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))
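In the shuffled branches, np.lexsort treats its last key as the primary one, so the capped sizes dominate the ordering and the random permutation only breaks ties; the trailing [::-1] then yields longest-first indices. A small sketch:

import numpy as np

sizes = np.array([5, 3, 5, 2])
order = [np.random.permutation(len(sizes)), np.minimum(sizes, 4)]
idx = np.lexsort(order)[::-1]   # primary key: capped size; ties broken randomly
assert list(np.minimum(sizes, 4)[idx]) == [4, 4, 3, 2]   # longest-first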
    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
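When normalize is set, F.layer_norm(wav, wav.shape) is simply per-utterance zero-mean, unit-variance scaling (with PyTorch's default eps of 1e-5). A quick equivalence check:

import torch
import torch.nn.functional as F

wav = torch.randn(16000)
normed = F.layer_norm(wav, wav.shape)
manual = (wav - wav.mean()) / torch.sqrt(wav.var(unbiased=False) + 1e-5)
assert torch.allclose(normed, manual, atol=1e-4)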
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/base_wrapper_dataset.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class BaseWrapperDataset(FairseqDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
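Since BaseWrapperDataset forwards every FairseqDataset hook to the wrapped dataset, a subclass only overrides what it changes. A minimal hypothetical wrapper (the class and its gain parameter are illustrative, not part of fairseq, and assume items are dicts with a "source" tensor):

class ScaledSourceDataset(BaseWrapperDataset):
    """Hypothetical example: rescale each item's "source" tensor on the fly."""

    def __init__(self, dataset, gain=0.5):
        super().__init__(dataset)
        self.gain = gain

    def __getitem__(self, index):
        item = self.dataset[index]
        item["source"] = item["source"] * self.gain
        return item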
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/concat_dataset.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # special handling for concatenating lang_pair_datasets
            indices = np.arange(len(self))
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            # sort by target length, then source length
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
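sample_ratios upsamples a constituent dataset by repeating its index range, and _get_dataset_and_sample_index wraps the oversampled index back with a modulo. A toy check using plain sequences in place of datasets (illustrative only):

import bisect

# Two "datasets" of lengths 3 and 2; the first upsampled 2x -> total length 8.
cumulative_sizes = ConcatDataset.cumsum(["abc", "de"], [2, 1])   # [6, 8]
idx = 4
dataset_idx = bisect.bisect_right(cumulative_sizes, idx)          # first dataset
sample_idx = (idx - (cumulative_sizes[dataset_idx - 1] if dataset_idx else 0)) % 3
assert (dataset_idx, sample_idx) == (0, 1)   # oversampled index 4 wraps to d0[1]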