ModelZoo / VITA-Audio_pytorch / Commits

Commit 39ac40a9
Authored Jun 06, 2025 by chenzk

v1.0

Pipeline #2747 failed in 0 seconds.

Changes: 427 · Pipelines: 1

Showing 20 changed files with 2724 additions and 0 deletions (+2724 -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/config/__init__.py +4 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/config/config.yaml +18 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/__init__.py +36 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/adaptive_loss.py +123 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/cross_entropy.py +90 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/ctc.py +304 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/fairseq_criterion.py +120 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/hubert_criterion.py +195 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/unispeech_criterion.py +143 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/wav2vec_criterion.py +195 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/wavlm_criterion.py +207 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/__init__.py +43 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/add_target_dataset.py +70 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/__init__.py +0 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/audio_utils.py +298 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/chunk_audio_dataset.py +354 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feats_dataset.py +282 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/__init__.py +82 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/global_cmvn.py +29 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/specaugment.py +131 -0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/config/__init__.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/config/config.yaml
0 → 100644
# @package _group_
hydra:
  run:
    dir: .

defaults:
  - task: null
  - model: null
  - criterion: cross_entropy
  - optimizer: null
  - lr_scheduler: fixed
  - bpe: null
  - tokenizer: null
  - scoring: null
  - generation: null
  - common_eval: null
  - eval_lm: null
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/__init__.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.criterions.fairseq_criterion import (  # noqa
    FairseqCriterion,
    LegacyFairseqCriterion,
)
from omegaconf import DictConfig


(
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--criterion", base_class=FairseqCriterion, default="cross_entropy"
)


def build_criterion(cfg: DictConfig, task):
    return build_criterion_(cfg, task)


# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.criterions." + file_name)
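For orientation, a minimal sketch of how this registry is used, mirroring the criterion files later in this commit; the name "my_criterion" and the classes below are hypothetical, not part of the commit:

# Hypothetical example: registering a custom criterion so that
# build_criterion can construct it by name from a config.
from dataclasses import dataclass

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MyCriterionConfig(FairseqDataclass):
    my_weight: float = 1.0  # exposed automatically as a config/CLI field


@register_criterion("my_criterion", dataclass=MyCriterionConfig)
class MyCriterion(FairseqCriterion):
    def __init__(self, task, my_weight):
        super().__init__(task)
        self.my_weight = my_weight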
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/adaptive_loss.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.dataclass.constants
import
DDP_BACKEND_CHOICES
from
omegaconf
import
II
@
dataclass
class
AdaptiveLossConfig
(
FairseqDataclass
):
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
ddp_backend
:
DDP_BACKEND_CHOICES
=
II
(
"distributed_training.ddp_backend"
)
@
register_criterion
(
"adaptive_loss"
,
dataclass
=
AdaptiveLossConfig
)
class
AdaptiveLoss
(
FairseqCriterion
):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def
__init__
(
self
,
task
,
sentence_avg
):
super
().
__init__
(
task
)
self
.
sentence_avg
=
sentence_avg
@
classmethod
def
build_criterion
(
cls
,
cfg
:
AdaptiveLossConfig
,
task
):
if
cfg
.
ddp_backend
in
{
"c10d"
,
"pytorch_ddp"
}:
raise
Exception
(
"AdaptiveLoss is not compatible with the PyTorch "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=legacy_ddp` instead."
)
return
cls
(
task
,
cfg
.
sentence_avg
)
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert
(
hasattr
(
model
.
decoder
,
"adaptive_softmax"
)
and
model
.
decoder
.
adaptive_softmax
is
not
None
)
adaptive_softmax
=
model
.
decoder
.
adaptive_softmax
net_output
=
model
(
**
sample
[
"net_input"
])
orig_target
=
model
.
get_targets
(
sample
,
net_output
)
nsentences
=
orig_target
.
size
(
0
)
orig_target
=
orig_target
.
view
(
-
1
)
bsz
=
orig_target
.
size
(
0
)
logits
,
target
=
adaptive_softmax
(
net_output
[
0
],
orig_target
)
assert
len
(
target
)
==
len
(
logits
)
loss
=
net_output
[
0
].
new
(
1
if
reduce
else
bsz
).
zero_
()
for
i
in
range
(
len
(
target
)):
if
target
[
i
]
is
not
None
:
assert
target
[
i
].
min
()
>=
0
and
target
[
i
].
max
()
<=
logits
[
i
].
size
(
1
)
loss
+=
F
.
cross_entropy
(
logits
[
i
],
target
[
i
],
ignore_index
=
self
.
padding_idx
,
reduction
=
"sum"
if
reduce
else
"none"
,
)
orig
=
utils
.
strip_pad
(
orig_target
,
self
.
padding_idx
)
ntokens
=
orig
.
numel
()
sample_size
=
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
ntokens
logging_output
=
{
"loss"
:
loss
.
data
,
"ntokens"
:
ntokens
,
"nsentences"
:
nsentences
,
"sample_size"
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
utils
.
item
(
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
))
ntokens
=
utils
.
item
(
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
))
sample_size
=
utils
.
item
(
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"nll_loss"
].
avg
)
)
else
:
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"loss"
].
avg
)
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
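A side note on the arithmetic in reduce_metrics above: the summed loss is in nats, and dividing by math.log(2) converts it to bits so that perplexity can be derived as a power of two. A quick check with an illustrative number:

import math

loss_nats = math.log(64)             # hypothetical per-token NLL, ~4.159 nats
loss_bits = loss_nats / math.log(2)  # 6.0 bits
ppl = 2 ** loss_bits                 # 64.0, identical to math.exp(loss_nats)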
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/cross_entropy.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        return loss, loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/ctc.py
0 → 100644
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round


@dataclass
class CtcCriterionConfig(FairseqDataclass):
    zero_infinity: bool = field(
        default=False,
        metadata={"help": "zero inf loss when source length <= target length"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
    post_process: str = field(
        default="letter",
        metadata={
            "help": "how to post process predictions into words. can be letter, "
            "wordpiece, BPE symbols, etc. "
            "See fairseq.data.data_utils.post_process() for full list of options"
        },
    )
    wer_kenlm_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
        },
    )
    wer_lexicon: Optional[str] = field(
        default=None,
        metadata={"help": "lexicon to use with wer_kenlm_model"},
    )
    wer_lm_weight: float = field(
        default=2.0,
        metadata={"help": "lm weight to use with wer_kenlm_model"},
    )
    wer_word_score: float = field(
        default=-1.0,
        metadata={"help": "lm word score to use with wer_kenlm_model"},
    )
    wer_args: Optional[str] = field(
        default=None,
        metadata={
            "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
        },
    )


@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
    def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask):
        super().__init__(task)
        self.blank_idx = (
            task.target_dictionary.index(task.blank_symbol)
            if hasattr(task, "blank_symbol")
            else 0
        )
        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = cfg.post_process

        if cfg.wer_args is not None:
            (
                cfg.wer_kenlm_model,
                cfg.wer_lexicon,
                cfg.wer_lm_weight,
                cfg.wer_word_score,
            ) = eval(cfg.wer_args)

        if cfg.wer_kenlm_model is not None:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = cfg.wer_kenlm_model
            dec_args.lexicon = cfg.wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = cfg.wer_lm_weight
            dec_args.word_score = cfg.wer_word_score
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = cfg.zero_infinity
        self.sentence_avg = cfg.sentence_avg

    def get_net_output(self, model, sample):
        net_output = model(**sample["net_input"])
        return net_output

    def get_loss(self, model, sample, net_output, reduce=True):
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            if net_output["padding_mask"] is not None:
                non_padding_mask = ~net_output["padding_mask"]
                input_lengths = non_padding_mask.long().sum(-1)
            else:
                input_lengths = lprobs.new_full(
                    (lprobs.size(1),), lprobs.size(0), dtype=torch.long
                )

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        if "target_lengths" in sample:
            target_lengths = sample["target_lengths"]
        else:
            target_lengths = pad_mask.sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.target_dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    def forward(self, model, sample, reduce=True):
        net_output = self.get_net_output(model, sample)
        loss, sample_size, logging_output = self.get_loss(
            model, sample, net_output, reduce
        )
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
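The evaluation branch in get_loss above performs greedy (best-path) CTC decoding: take the argmax over the vocabulary at each frame, collapse consecutive repeats, then drop blanks. The same three steps in isolation, on made-up log-probabilities:

import torch

blank_idx = 0  # as in CtcCriterion when the task defines no blank symbol
lprobs = torch.log_softmax(torch.randn(6, 4), dim=-1)  # 6 frames, 4 symbols

toks = lprobs.argmax(dim=-1)             # best symbol per frame, e.g. [2, 2, 0, 3, 3, 0]
toks = toks.unique_consecutive()         # collapse repeats:            [2, 0, 3, 0]
pred = toks[toks != blank_idx].tolist()  # drop blanks:                 [2, 3]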
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/fairseq_criterion.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    def __init__(self, task):
        super().__init__()
        self.task = task
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, cfg: FairseqDataclass, task):
        """Construct a criterion from command-line args."""
        # arguments in the __init__.
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if (
                p.kind == p.POSITIONAL_ONLY
                or p.kind == p.VAR_POSITIONAL
                or p.kind == p.VAR_KEYWORD
            ):
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(p.kind))

            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

            if p.name == "task":
                init_args["task"] = task
            elif p.name == "cfg":
                init_args["cfg"] = cfg
            elif hasattr(cfg, p.name):
                init_args[p.name] = getattr(cfg, p.name)
            elif p.default != p.empty:
                pass  # we'll use the default value
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


class LegacyFairseqCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        return cls(args, task)
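The generic build_criterion above fills a criterion's __init__ parameters by name from the config via inspect.signature. A standalone toy demonstrating the same mechanism; Toy and its fields are made up for illustration:

import inspect
from types import SimpleNamespace


class Toy:
    def __init__(self, task, sentence_avg, label_smoothing=0.0):
        self.task = task
        self.sentence_avg = sentence_avg
        self.label_smoothing = label_smoothing


cfg = SimpleNamespace(sentence_avg=True)  # no label_smoothing: default applies
init_args = {}
for p in inspect.signature(Toy).parameters.values():
    if p.name == "task":
        init_args["task"] = "dummy-task"
    elif hasattr(cfg, p.name):
        init_args[p.name] = getattr(cfg, p.name)
    # otherwise fall back to the parameter's default, as build_criterion does

toy = Toy(**init_args)  # Toy(task="dummy-task", sentence_avg=True)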
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/hubert_criterion.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.models.hubert import ILSHubertModel
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class HubertCriterionConfig(FairseqDataclass):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def get_net_output(self, model, sample):
        """compute the loss for the given sample"""
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        return net_output

    def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            if isinstance(model, ILSHubertModel):
                if model.weighted_sum:
                    norm_weights = F.softmax(model.weights, dim=-1)
                    loss_m_list = norm_weights * torch.stack(loss_m_list, dim=0)
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            if model.weighted_sum:
                norm_weights = F.softmax(model.weights, dim=-1)
                loss_u_list = norm_weights * torch.stack(loss_u_list, dim=0)
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float(net_output[lk])

        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    def forward(self, model, sample, reduce=True, log_pred=False):
        net_output = self.get_net_output(model, sample)
        loss, sample_size, logging_output = self.get_loss(
            model, sample, net_output, reduce, log_pred
        )
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/unispeech_criterion.py
0 → 100644
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import math
from dataclasses import dataclass, field

from fairseq import pdb
from fairseq import utils, metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.wav2vec_criterion import Wav2vecCriterion, Wav2VecCriterionConfig
from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
from fairseq.logging.meters import safe_round


@dataclass
class UnispeechCriterionConfig(Wav2VecCriterionConfig, CtcCriterionConfig):
    mtlalpha: float = field(
        default=0.5,
        metadata={"help": "loss weight for multitask learning"},
    )


@register_criterion('unispeech_criterion', dataclass=UnispeechCriterionConfig)
class UnispeechCriterion(FairseqCriterion):
    def __init__(self, cfg: UnispeechCriterionConfig, task):
        super().__init__(task)
        self.mtlalpha = cfg.mtlalpha
        self.w2v_criterion = Wav2vecCriterion(
            task, cfg.infonce, cfg.loss_weights, cfg.log_keys
        )
        if self.mtlalpha > 0:
            self.ctc_criterion = CtcCriterion(cfg, task)

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])

        if self.mtlalpha > 0.0:
            ctc_loss, ctc_sample_size, ctc_logging_output = self.ctc_criterion.get_loss(
                model, sample, net_output, reduce
            )
        else:
            ctc_loss = 0
            ctc_sample_size = 0
            ctc_logging_output = {}

        infonce_loss, infonce_sample_size, infonce_logging_output = self.w2v_criterion.get_loss(
            model.w2v_encoder.w2v_model, sample, net_output['contrastive_res'], reduce
        )

        loss = self.mtlalpha * ctc_loss + (1.0 - self.mtlalpha) * infonce_loss
        sample_size = infonce_sample_size

        logging_output = {
            'loss': loss,
            'ntokens': ctc_logging_output['ntokens'],
            'nsentences': ctc_logging_output['nsentences'],
            'ctc': ctc_logging_output,
            'infonce': infonce_logging_output,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ctc_loss_sum = utils.item(
            sum(log['ctc'].get('loss', 0) for log in logging_outputs)
        )
        ctc_sample_size = utils.item(
            sum(log['ctc'].get('sample_size', 0) for log in logging_outputs)
        )
        ctc_ntokens = utils.item(
            sum(log['ctc'].get('ntokens', 0) for log in logging_outputs)
        )
        ctc_nsentences = utils.item(
            sum(log['ctc'].get('nsentences', 0) for log in logging_outputs)
        )
        ctras_loss_sum = utils.item(
            sum(log['infonce'].get('loss', 0) for log in logging_outputs)
        )
        ctras_sample_size = utils.item(
            sum(log['infonce'].get('sample_size', 0) for log in logging_outputs)
        )
        ctras_ntokens = utils.item(
            sum(log['infonce'].get('ntokens', 0) for log in logging_outputs)
        )
        ctras_nsentences = utils.item(
            sum(log['infonce'].get('nsentences', 0) for log in logging_outputs)
        )

        metrics.log_scalar("loss", loss_sum, 1, round=3)
        metrics.log_scalar(
            "ctc_loss", ctc_loss_sum / ctc_sample_size / math.log(2), ctc_sample_size, round=3
        )
        metrics.log_scalar(
            "contrastive_loss",
            ctras_loss_sum / ctras_sample_size / math.log(2),
            ctras_sample_size,
            round=3,
        )
        if ctc_sample_size != ctc_ntokens:
            metrics.log_scalar(
                "nll_loss", ctc_loss_sum / ctc_ntokens / math.log(2), ctc_ntokens, round=3
            )

        c_errors = sum(log['ctc'].get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log['ctc'].get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log['ctc'].get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

        metrics.log_scalar("nsentences", ctras_nsentences)
        metrics.log_scalar("ctc_sample_size", ctc_sample_size)
        metrics.log_scalar("contrastive_sample_size", ctras_sample_size)

        correct = sum(log['infonce'].get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)
        total = sum(log['infonce'].get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5
                )
                if meters["_total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
        for k in logging_outputs[0]['infonce']:
            if k not in builtin_keys:
                val = sum(log['infonce'].get(k, 0) for log in logging_outputs) / len(
                    logging_outputs
                )
                if k.startswith('loss'):
                    metrics.log_scalar(
                        k, val / ctras_sample_size / math.log(2), ctras_sample_size
                    )
                else:
                    metrics.log_scalar(k, val, round=3)
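The multitask weighting in UnispeechCriterion.forward is a plain convex combination, loss = mtlalpha * ctc_loss + (1 - mtlalpha) * infonce_loss, so mtlalpha=1.0 trains on CTC alone and mtlalpha=0.0 on the contrastive loss alone (the criterion then skips building CtcCriterion entirely). With illustrative numbers:

mtlalpha = 0.5                        # default in UnispeechCriterionConfig
ctc_loss, infonce_loss = 120.0, 80.0  # hypothetical summed losses
loss = mtlalpha * ctc_loss + (1.0 - mtlalpha) * infonce_loss  # 100.0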
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/wav2vec_criterion.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.utils import is_xla_tensor


@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
    infonce: bool = field(
        default=False,
        metadata={
            "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
        },
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
    def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.infonce = infonce
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
        logits = model.get_logits(net_output).float()
        target = model.get_targets(sample, net_output)

        weights = None
        if hasattr(model, "get_target_weights") and not self.infonce:
            weights = model.get_target_weights(target, net_output)
            if torch.is_tensor(weights):
                weights = weights.float()

        losses = []

        if self.infonce:
            loss = F.cross_entropy(
                logits,
                target,
                reduction="sum" if reduce else "none",
            )
        else:
            loss = F.binary_cross_entropy_with_logits(
                logits,
                target.float(),
                weights,
                reduction="sum" if reduce else "none",
            )

        sample_size = target.numel() if self.infonce else target.long().sum().item()
        losses.append(loss.detach().clone())

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, coef in zip(extra_losses, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    losses.append(p)

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float(net_output[lk])

        if len(losses) > 1:
            for i, l in enumerate(losses):
                logging_output[f"loss_{i}"] = l.item()

        if self.infonce:
            with torch.no_grad():
                if logits.numel() == 0:
                    corr = 0
                    count = 0
                else:
                    assert logits.dim() > 1, logits.shape
                    max = logits.argmax(-1) == 0
                    min = logits.argmin(-1) == 0
                    both = max & min
                    corr = max.long().sum().item() - both.long().sum().item()
                    count = max.numel()

                logging_output["correct"] = corr
                logging_output["count"] = count

        if log_pred:
            logging_output["logits"] = logits.cpu().numpy()
            logging_output["target"] = target.cpu().numpy()
        return loss, sample_size, logging_output

    def get_net_output(self, model, sample):
        net_output = model(**sample["net_input"])
        return net_output

    def forward(self, model, sample, reduce=True, log_pred=False):
        net_output = self.get_net_output(model, sample)
        loss, sample_size, logging_output = self.get_loss(
            model, sample, net_output, reduce, log_pred
        )
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5
                )
                if meters["_total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {
            "loss",
            "ntokens",
            "nsentences",
            "sample_size",
            "correct",
            "count",
        }

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs) / len(
                    logging_outputs
                )
                if k.startswith("loss"):
                    metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
                else:
                    metrics.log_scalar(k, val, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
\ No newline at end of file
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/criterions/wavlm_criterion.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class WavLMCriterionConfig(FairseqDataclass):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("wavlm", dataclass=WavLMCriterionConfig)
class WavLMCriterion(FairseqCriterion):
    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def get_net_output(self, model, sample):
        """compute the loss for the given sample"""
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        return net_output

    def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float(net_output[lk])

        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    def forward(self, model, sample, reduce=True, log_pred=False):
        net_output = self.get_net_output(model, sample)
        loss, sample_size, logging_output = self.get_loss(
            model, sample, net_output, reduce, log_pred
        )
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)

        builtin_keys = {
            "loss",
            "nll_loss",
            "ppl",
            "ntokens",
            "nsentences",
            "sample_size",
            "correct",
            "count",
        }

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val
                builtin_keys.add(lk)

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
            elif lk not in builtin_keys:
                val = sum(log.get(lk, 0) for log in logging_outputs) / len(
                    logging_outputs
                )
                metrics.log_scalar(lk, val, round=3)

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/__init__.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .dictionary import Dictionary, TruncatedDictionary

from .fairseq_dataset import FairseqDataset, FairseqIterableDataset

from .base_wrapper_dataset import BaseWrapperDataset

from .add_target_dataset import AddTargetDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .audio.utterance_mixing_dataset import UtteranceMixingDataset
from .concat_dataset import ConcatDataset
from .id_dataset import IdDataset
from .resampling_dataset import ResamplingDataset
from .iterators import (
    CountingIterator,
    EpochBatchIterator,
    GroupedIterator,
    ShardedIterator,
)
from .monolingual_dataset import MonolingualDataset

__all__ = [
    "AddTargetDataset",
    "ConcatDataset",
    "CountingIterator",
    "Dictionary",
    "EpochBatchIterator",
    "FairseqDataset",
    "FairseqIterableDataset",
    "FileAudioDataset",
    "GroupedIterator",
    "HubertDataset",
    "IdDataset",
    "ResamplingDataset",
    "ShardedIterator",
]
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/add_target_dataset.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset, data_utils


class AddTargetDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        add_to_input=False,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.add_to_input = add_to_input

    def get_label(self, index):
        return (
            self.labels[index]
            if self.process_label is None
            else self.process_label(self.labels[index])
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index)
        return item

    def size(self, index):
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target

        if self.add_to_input:
            eos = target.new_full((target.size(0), 1), self.eos)
            collated["target"] = torch.cat([target, eos], dim=-1).long()
            collated["net_input"]["prev_output_tokens"] = torch.cat(
                [eos, target], dim=-1
            ).long()
            collated["ntokens"] += target.size(0)
        return collated
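The add_to_input branch of collater above builds teacher-forcing tensors: eos is appended to the target and prepended to form prev_output_tokens. The same shift on made-up token ids:

import torch

eos_idx = 2
target = torch.tensor([[7, 8, 9], [4, 5, 1]])      # hypothetical padded batch
eos = target.new_full((target.size(0), 1), eos_idx)

shifted_target = torch.cat([target, eos], dim=-1)  # [[7, 8, 9, 2], [4, 5, 1, 2]]
prev_output = torch.cat([eos, target], dim=-1)     # [[2, 7, 8, 9], [2, 4, 5, 1]]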
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/__init__.py
0 → 100644
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/audio_utils.py
0 → 100644
from pathlib import Path
from typing import BinaryIO, Optional, Tuple, Union, List

import numpy as np
import torch
import io
import json
import librosa
import scipy
import soundfile as sf

SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}


def preemphasis(x, preemph):
    return scipy.signal.lfilter([1, -preemph], [1], x)


def mulaw_encode(x, mu):
    mu = mu - 1
    fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((fx + 1) / 2 * mu + 0.5)


def mulaw_decode(y, mu):
    mu = mu - 1
    x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
    return x
def
_convert_to_mono
(
waveform
:
torch
.
FloatTensor
,
sample_rate
:
int
)
->
torch
.
FloatTensor
:
if
waveform
.
shape
[
0
]
>
1
:
try
:
import
torchaudio.sox_effects
as
ta_sox
except
ImportError
:
raise
ImportError
(
"Please install torchaudio to convert multi-channel audios"
)
effects
=
[[
'channels'
,
'1'
]]
return
ta_sox
.
apply_effects_tensor
(
waveform
,
sample_rate
,
effects
)[
0
]
return
waveform
def
convert_to_mono
(
waveform
:
np
.
ndarray
,
sample_rate
:
int
)
->
np
.
ndarray
:
if
waveform
.
shape
[
0
]
>
1
:
_waveform
=
torch
.
from_numpy
(
waveform
)
return
_convert_to_mono
(
_waveform
,
sample_rate
).
numpy
()
return
waveform
def
get_waveform
(
path_or_fp
:
Union
[
str
,
BinaryIO
],
normalization
=
True
,
mono
=
True
,
frames
=-
1
,
start
=
0
,
always_2d
=
True
)
->
Tuple
[
np
.
ndarray
,
int
]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): Normalize values to [-1, 1] (Default: True)
mono (bool): convert multi-channel audio to mono-channel one
frames (int): the number of frames to read. (-1 for reading all)
start (int): Where to start reading. A negative value counts from the end.
always_2d (bool): always return 2D array even for mono-channel audios
Returns:
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
sample_rate (float): sample rate
"""
if
isinstance
(
path_or_fp
,
str
):
ext
=
Path
(
path_or_fp
).
suffix
if
ext
not
in
SF_AUDIO_FILE_EXTENSIONS
:
raise
ValueError
(
f
"Unsupported audio format:
{
ext
}
"
)
try
:
import
soundfile
as
sf
except
ImportError
:
raise
ImportError
(
"Please install soundfile to load WAV/FLAC/OGG Vorbis audios"
)
waveform
,
sample_rate
=
sf
.
read
(
path_or_fp
,
dtype
=
"float32"
,
always_2d
=
True
,
frames
=
frames
,
start
=
start
)
waveform
=
waveform
.
T
# T x C -> C x T
if
mono
and
waveform
.
shape
[
0
]
>
1
:
waveform
=
convert_to_mono
(
waveform
,
sample_rate
)
if
not
normalization
:
waveform
*=
2
**
15
# denormalized to 16-bit signed integers
if
not
always_2d
:
waveform
=
waveform
.
squeeze
(
axis
=
0
)
return
waveform
,
sample_rate
def
_get_kaldi_fbank
(
waveform
:
np
.
ndarray
,
sample_rate
:
int
,
n_bins
=
80
)
->
Optional
[
np
.
ndarray
]:
"""Get mel-filter bank features via PyKaldi."""
try
:
from
kaldi.feat.mel
import
MelBanksOptions
from
kaldi.feat.fbank
import
FbankOptions
,
Fbank
from
kaldi.feat.window
import
FrameExtractionOptions
from
kaldi.matrix
import
Vector
mel_opts
=
MelBanksOptions
()
mel_opts
.
num_bins
=
n_bins
frame_opts
=
FrameExtractionOptions
()
frame_opts
.
samp_freq
=
sample_rate
opts
=
FbankOptions
()
opts
.
mel_opts
=
mel_opts
opts
.
frame_opts
=
frame_opts
fbank
=
Fbank
(
opts
=
opts
)
features
=
fbank
.
compute
(
Vector
(
waveform
.
squeeze
()),
1.0
).
numpy
()
return
features
except
ImportError
:
return
None
def
_get_torchaudio_fbank
(
waveform
:
np
.
ndarray
,
sample_rate
,
n_bins
=
80
)
->
Optional
[
np
.
ndarray
]:
"""Get mel-filter bank features via TorchAudio."""
try
:
import
torchaudio.compliance.kaldi
as
ta_kaldi
waveform
=
torch
.
from_numpy
(
waveform
)
features
=
ta_kaldi
.
fbank
(
waveform
,
num_mel_bins
=
n_bins
,
sample_frequency
=
sample_rate
)
return
features
.
numpy
()
except
ImportError
:
return
None
def
get_fbank
(
path_or_fp
:
Union
[
str
,
BinaryIO
],
n_bins
=
80
)
->
np
.
ndarray
:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
waveform
,
sample_rate
=
get_waveform
(
path_or_fp
,
normalization
=
False
)
features
=
_get_kaldi_fbank
(
waveform
,
sample_rate
,
n_bins
)
if
features
is
None
:
features
=
_get_torchaudio_fbank
(
waveform
,
sample_rate
,
n_bins
)
if
features
is
None
:
raise
ImportError
(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
return
features
def
is_npy_data
(
data
:
bytes
)
->
bool
:
return
data
[
0
]
==
147
and
data
[
1
]
==
78
def
is_sf_audio_data
(
data
:
bytes
)
->
bool
:
is_wav
=
(
data
[
0
]
==
82
and
data
[
1
]
==
73
and
data
[
2
]
==
70
)
is_flac
=
(
data
[
0
]
==
102
and
data
[
1
]
==
76
and
data
[
2
]
==
97
)
is_ogg
=
(
data
[
0
]
==
79
and
data
[
1
]
==
103
and
data
[
2
]
==
103
)
return
is_wav
or
is_flac
or
is_ogg
def
read_from_stored_zip
(
zip_path
:
str
,
offset
:
int
,
file_size
:
int
)
->
bytes
:
with
open
(
zip_path
,
"rb"
)
as
f
:
f
.
seek
(
offset
)
data
=
f
.
read
(
file_size
)
return
data
def
parse_path
(
path
:
str
)
->
Tuple
[
str
,
List
[
int
]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if
Path
(
path
).
suffix
in
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS
:
_path
,
slice_ptr
=
path
,
[]
else
:
_path
,
*
slice_ptr
=
path
.
split
(
":"
)
if
not
Path
(
_path
).
is_file
():
raise
FileNotFoundError
(
f
"File not found:
{
_path
}
"
)
assert
len
(
slice_ptr
)
in
{
0
,
2
},
f
"Invalid path:
{
path
}
"
slice_ptr
=
[
int
(
i
)
for
i
in
slice_ptr
]
return
_path
,
slice_ptr
def
_group_to_batches_by_utters
(
buffer
,
sorted_idx_len_pair
,
batch_size
):
batch_list
=
[]
single_batch
=
[]
for
idx_len_pair
in
sorted_idx_len_pair
:
single_batch
.
append
(
buffer
[
idx_len_pair
[
0
]])
if
len
(
single_batch
)
==
batch_size
:
batch_list
.
append
(
single_batch
)
single_batch
=
[]
if
len
(
single_batch
)
>
0
:
batch_list
.
append
(
single_batch
)
return
batch_list
def
_group_to_batches_by_frames
(
buffer
,
sorted_idx_len_pair
,
batch_size
):
batch_list
=
[]
single_batch
=
[]
frame_num_padded
=
0
first_utt_len
=
sorted_idx_len_pair
[
0
][
1
]
max_sentence
=
batch_size
//
first_utt_len
//
8
*
8
for
idx_len_pair
in
sorted_idx_len_pair
:
if
max_sentence
==
0
:
max_sentence
=
8
frame_num_padded
+=
first_utt_len
if
frame_num_padded
>
batch_size
or
len
(
single_batch
)
==
max_sentence
:
if
len
(
single_batch
)
>
0
:
batch_list
.
append
(
single_batch
)
single_batch
=
[]
first_utt_len
=
idx_len_pair
[
1
]
frame_num_padded
=
first_utt_len
max_sentence
=
batch_size
//
first_utt_len
//
8
*
8
single_batch
.
append
(
buffer
[
idx_len_pair
[
0
]])
if
len
(
single_batch
)
>
0
:
batch_list
.
append
(
single_batch
)
return
batch_list
def
_group_to_batches_by_frame_x_label
(
buffer
,
sorted_idx_len_pair
,
batch_size
):
batch_list
=
[]
single_batch
=
[]
frame_num_padded
=
0
max_lab_len
=
sorted_idx_len_pair
[
0
][
2
]
+
1
max_utt_len
=
sorted_idx_len_pair
[
0
][
1
]
for
idx_len_pair
in
sorted_idx_len_pair
:
if
max_lab_len
<
idx_len_pair
[
2
]
+
1
:
max_lab_len
=
idx_len_pair
[
2
]
+
1
frame_num_padded
=
max_utt_len
*
max_lab_len
*
(
len
(
single_batch
)
)
if
frame_num_padded
>
batch_size
:
if
len
(
single_batch
)
>
0
:
batch_list
.
append
(
single_batch
)
single_batch
=
[]
max_utt_len
=
idx_len_pair
[
1
]
max_lab_len
=
idx_len_pair
[
2
]
+
1
single_batch
.
append
(
buffer
[
idx_len_pair
[
0
]])
if
len
(
single_batch
)
>
0
:
batch_list
.
append
(
single_batch
)
return
batch_list
class
DataParser
():
def
__init__
(
self
):
super
().
__init__
()
def
_parse_data
(
self
,
data
,
data_type
):
if
data_type
.
lower
()
==
'audio'
:
parsed_data
=
self
.
_parse_audio_data
(
data
)
elif
data_type
.
lower
()
==
'info'
:
parsed_data
=
self
.
_parse_json_data
(
data
)
elif
data_type
.
lower
()
==
"feature"
:
parsed_data
=
self
.
_parse_feat_data
(
data
)
else
:
parsed_data
=
self
.
_parse_string_data
(
data
)
return
parsed_data
def
_parse_audio_data
(
self
,
data
):
byte_stream
=
io
.
BytesIO
(
data
)
with
sf
.
SoundFile
(
byte_stream
,
'r'
)
as
f
:
samples
=
f
.
read
()
return
samples
def
_parse_json_data
(
self
,
data
):
str_data
=
str
(
data
,
'utf-8'
)
json_data
=
json
.
loads
(
str_data
)
return
json_data
def
_parse_string_data
(
self
,
data
):
str_data
=
str
(
data
,
'utf-8'
)
return
str_data
def
_parse_feat_data
(
self
,
data
):
feat
=
np
.
frombuffer
(
data
,
dtype
=
np
.
float32
)
feat
=
feat
.
reshape
(
-
1
,
80
)
return
feat
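
Two quick checks of the helpers above (a sketch, not part of the original file; the code-to-[-1, 1] rescaling step is an assumption about how callers invert mulaw_encode):

import numpy as np

# Mu-law round trip: encode to 256 discrete codes, then invert.
x = np.linspace(-0.9, 0.9, 1001).astype(np.float32)
codes = mulaw_encode(x, mu=256)        # integer-valued codes in [0, 255]
y = 2.0 * codes / 255.0 - 1.0          # rescale codes back to [-1, 1]
x_hat = mulaw_decode(y, mu=256)
assert np.abs(x - x_hat).max() < 0.05  # only quantization error remains

# parse_path: plain audio paths pass through untouched ...
assert parse_path("a/b/c.wav") == ("a/b/c.wav", [])
# ... while "archive.zip:1024:2048" would yield ("archive.zip", [1024, 2048]),
# provided archive.zip exists (that branch checks Path(_path).is_file()).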
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/chunk_audio_dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pdb
import logging
import os
import sys
import json

import soundfile as sf
import numpy as np
import torch
import torch.nn.functional as F

from .. import FairseqDataset, data_utils
from fairseq.tokenizer import char_tokenizer
from fairseq.data.audio.audio_utils import (
    _group_to_batches_by_frames,
    _group_to_batches_by_utters,
    _group_to_batches_by_frame_x_label,
    DataParser,
)

ENDIAN = "little"

logger = logging.getLogger(__name__)


class ChunkAudioDataset(torch.utils.data.IterableDataset, FairseqDataset):
    def __init__(
        self,
        chunk_data_file,
        chunk_data_path=None,
        chunk_trans_path=None,
        max_sample_size=None,
        min_sample_size=None,
        max_tokens=None,
        pad=False,
        normalize=False,
        subset=None,
        shuffle=True,
        shard=True,
        label=False,
        dictionary=None,
        feature="audio",
        mean_file=None,
        invstd_file=None,
        batch_criterion="frame",
    ):
        self._data_path = chunk_data_path
        self._data_file = chunk_data_file
        self._trans_path = chunk_trans_path
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.max_tokens = max_tokens
        self.pad = pad
        self.shuffle = shuffle
        self.shard = shard
        self.normalize = normalize
        self.label = label
        self.dictionary = dictionary
        self.feature = feature
        if mean_file is not None:
            self.mean = np.fromfile(mean_file, sep="\n")
        else:
            self.mean = None
        if invstd_file is not None:
            self.invstd = np.fromfile(invstd_file, sep="\n")
        else:
            self.invstd = None

        with open(self._data_file) as f:
            self._chunk_list = json.load(f)["fileInfo"]
        if self._data_path is None:
            self._data_path = os.path.dirname(self._data_file)
        if self._trans_path is None:
            self._trans_path = os.path.dirname(self._data_file)
        self._chunk_num = len(self._chunk_list)
        self._example_num = 0
        self._dist_size = 1
        self._dist_rank = 0
        self.end_of_epoch = False
        for chunk in self._chunk_list:
            self._example_num += int(chunk["count"])
        logger.info(
            f"Open dataset {self._data_file}, total example count {self._example_num}"
        )
        self.subset = subset
        self.parser = DataParser()
        self._buffer_size = 3000
        self._batch_criterion = batch_criterion
        self._example_buffer = []
        self._batch_buffer = []
        self._first_iteration = True
        self.iterable = None

    def __len__(self):
        return self._example_num

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            offset = self._dist_rank
            skip = self._dist_size
        else:
            offset = self._dist_size * worker_info.id + self._dist_rank
            skip = self._dist_size * worker_info.num_workers
        # print(self._chunk_list[13])
        if self.shard:
            self._sharded_list = list(self._chunk_list[offset::skip])
            value = len(self._chunk_list) % self._dist_size
            if value != 0 and self._dist_rank >= value:
                if worker_info is None or worker_info.id == worker_info.num_workers - 1:
                    np.random.seed(self._dist_rank)
                    pad_chunk = np.random.choice(self._chunk_list)
                    self._sharded_list.append(pad_chunk)
        else:
            self._sharded_list = self._chunk_list
        self.iterable = iter(self._chunk_deserializer())
        # print("{}/{} worker init in gpu {}, sharded data {}/{}".format(
        #     worker_info.id, worker_info.num_workers, self._dist_rank,
        #     len(self._sharded_list), len(self._chunk_list)))
        return self

    def reset(self, world_size=1, world_rank=0):
        # print("Reset Dataset")
        self._example_buffer = []
        self._batch_buffer = []
        self._first_iteration = True
        self._dist_size = world_size
        self._dist_rank = world_rank
        np.random.seed(self.epoch)
        if self.shuffle:
            np.random.shuffle(self._chunk_list)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __next__(self):
        return self._dynamicbatcher()

    def _read_chunk(self, file_path, chunk_type, chunk_size):
        example_list = []
        with open(file_path, "rb") as f:
            target_type = f.read(len(chunk_type.encode())).decode()
            if chunk_type.lower() != target_type.lower():
                raise ValueError(
                    "Target type is not expected in {}, expected {}, but got {}".format(
                        file_path, chunk_type, target_type
                    )
                )
            version_number = int.from_bytes(f.read(4), byteorder=ENDIAN)
            for i in range(chunk_size):
                example_index = int.from_bytes(f.read(4), byteorder=ENDIAN)
                if example_index != i:
                    raise ValueError(
                        "The example index is corrupted in {}, "
                        "expected {}, but got {}".format(file_path, i, example_index)
                    )
                data_size = int.from_bytes(f.read(4), byteorder=ENDIAN)
                data = f.read(data_size)
                example_list.append(data)
        return example_list

    def _chunk_deserializer(self):
        try:
            iterator = iter(self._sharded_list)
            chunk = next(iterator)
            while True:
                chunk_type = ["info", self.feature]
                if self.label:
                    chunk_type.append("transcription")
                chunk_name = chunk["name"]
                chunk_size = int(chunk["count"])
                example_dict = {}
                for extension in chunk_type:
                    if extension == "transcription":
                        file_path = os.path.join(
                            self._trans_path, chunk_name + ".transcription"
                        )
                    else:
                        file_path = os.path.join(
                            self._data_path, chunk_name + "." + extension
                        )
                    example_dict[extension] = self._read_chunk(
                        file_path, extension, chunk_size
                    )
                example_lens = [len(example_dict[x]) for x in chunk_type]
                if not all(x == chunk_size for x in example_lens):
                    error_msg = "Chunk size is not consistent in chunk {}".format(
                        chunk_name
                    )
                    raise ValueError(error_msg)
                for i in range(chunk_size):
                    one_example = {}
                    for extension in chunk_type:
                        one_example[extension] = self.parser._parse_data(
                            example_dict[extension][i], extension
                        )
                    if (
                        self.subset is not None
                        and self.subset not in one_example["info"]["corpusname"]
                    ):
                        break
                    if "transcription" in one_example:
                        one_example["y"] = self.dictionary.encode_line(
                            one_example["transcription"].upper(),
                            line_tokenizer=char_tokenizer,
                            add_if_not_exist=False,
                            append_eos=False,
                        )
                    if self.feature not in one_example:
                        continue
                    yield one_example
                chunk = next(iterator)
        except StopIteration:
            return

    def _fill_buffer_by_length(self, buffer, length):
        try:
            i = 0
            while i < length:
                example = next(self.iterable)
                x_len = example[self.feature].shape[0]
                if (
                    self.pad
                    and self.max_sample_size is not None
                    and x_len > self.max_sample_size
                ):
                    continue
                if self.min_sample_size is not None and x_len < self.min_sample_size:
                    continue
                buffer.append(example)
                i += 1
        except StopIteration:
            pass

    def _create_batch_list(self, example_list):
        idx_len_pair = []
        for idx in range(len(example_list)):
            uttlen = len(example_list[idx][self.feature])
            if "y" in example_list[idx]:
                target_len = len(example_list[idx]["y"])
            else:
                target_len = 1
            idx_len_pair.append((idx, uttlen, target_len))
        sorted_idx_len_pair = sorted(
            idx_len_pair, key=lambda var: var[1], reverse=self.pad
        )
        if self._batch_criterion == "frame":
            group_batches_fn = _group_to_batches_by_frames
        elif self._batch_criterion == "utterance":
            group_batches_fn = _group_to_batches_by_utters
        elif self._batch_criterion == "frame_x_label":
            group_batches_fn = _group_to_batches_by_frame_x_label
        else:
            raise ValueError(
                "Only supports grouping batches by 'frame', 'utterance', 'frame_x_label'"
            )
        batch_list = group_batches_fn(
            self._example_buffer, sorted_idx_len_pair, self.max_tokens
        )
        if self.shuffle:
            np.random.shuffle(batch_list)
        return batch_list

    def _dynamicbatcher(self):
        if self._first_iteration:
            self._first_iteration = False
            self._fill_buffer_by_length(self._example_buffer, self._buffer_size)
            if self.shuffle:
                np.random.shuffle(self._example_buffer)
        if not self._batch_buffer and not self._example_buffer:
            raise StopIteration
        if not self._batch_buffer:
            self._batch_buffer = self._create_batch_list(self._example_buffer)
            self._example_buffer = []
        single_batch = self._batch_buffer.pop()
        self._fill_buffer_by_length(self._example_buffer, len(single_batch))
        if self.feature == "audio":
            sources = [
                self.postprocess(torch.from_numpy(s[self.feature])).float()
                for s in single_batch
            ]
        else:
            sources = [
                torch.from_numpy(self.mvn(s[self.feature])).float()
                for s in single_batch
            ]
        infos = [s["info"] for s in single_batch]
        ids = torch.LongTensor(list(range(len(single_batch))))
        if self.label:
            target = [s["y"] for s in single_batch]
            return {"info": infos, "id": ids, "source": sources, "target": target}
        return {"info": infos, "id": ids, "source": sources}

    def collater(self, samples):
        samples = samples[0]
        if len(samples["source"]) == 0:
            return {}
        sources = samples["source"]
        sizes = [len(s) for s in sources]
        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)
        if self.feature == "audio":
            collated_sources = sources[0].new_zeros(len(sources), target_size)
        else:
            collated_sources = sources[0].new_zeros(len(sources), target_size, 80)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                assert self.pad
                if self.feature == "audio":
                    collated_sources[i] = torch.cat(
                        [source, source.new_full((-diff,), 0.0)]
                    )
                else:
                    collated_sources[i] = torch.cat(
                        [source, source.new_full((-diff, 80), 0.0)]
                    )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i] = self.crop_to_max_size(source, target_size)

        input = {"source": collated_sources}
        if self.pad:
            input["padding_mask"] = padding_mask
        collated = {"info": samples["info"], "id": samples["id"], "net_input": input}
        if not self.label:
            return collated
        target = samples["target"]
        collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
        target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
        collated["ntokens"] = collated["target_lengths"].sum().item()
        collated["target"] = target
        return collated

    def postprocess(self, feats):
        if feats.dim() == 2:
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats

    def mvn(self, feats):
        feats = (feats - self.mean) * self.invstd
        return feats

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav
        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end]
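
A self-contained sketch (not part of the original file) of the on-disk chunk layout that _read_chunk expects: a type tag matching the file extension, a 4-byte little-endian version number, then for each example a 4-byte running index, a 4-byte payload size, and the payload bytes.

import os
import tempfile

payloads = [b"hello", b"world!"]
fd, tmp_path = tempfile.mkstemp(suffix=".info")
with os.fdopen(fd, "wb") as f:
    f.write(b"info")                           # type tag ("info" chunk)
    f.write((1).to_bytes(4, "little"))         # version number
    for i, p in enumerate(payloads):
        f.write(i.to_bytes(4, "little"))       # running example index
        f.write(len(p).to_bytes(4, "little"))  # payload size
        f.write(p)                             # payload bytes

with open(tmp_path, "rb") as f:                # mirrors _read_chunk
    assert f.read(4).decode() == "info"
    _version = int.from_bytes(f.read(4), "little")
    recovered = []
    for i in range(len(payloads)):
        assert int.from_bytes(f.read(4), "little") == i
        size = int.from_bytes(f.read(4), "little")
        recovered.append(f.read(size))
os.unlink(tmp_path)
assert recovered == payloads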
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feats_dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import io
import pdb

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import data_utils
from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
    mulaw_encode,
    preemphasis,
)

logger = logging.getLogger(__name__)


class FeatsAudioDataset(RawAudioDataset):
    def __init__(
        self,
        manifest_path,
        sample_rate,
        input_feature="mfcc",
        output_feature="mfcc",
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=False,
            normalize=normalize,
        )
        self.chunk_names = []
        self.chunk_indices = []
        self.fnames = []
        self.skipped = []
        self.speakers = []
        self.input_feature = input_feature
        self.output_feature = output_feature
        self.speaker_dict = {}
        speaker_count = 0
        skipped = 0
        count = 0
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                # assert len(items) == 2, line
                sz = int(items[1])
                if self.input_feature != "wav":
                    sz = int(sz / self.sample_rate * 100)
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped.append(i)
                    self.skipped_indices.add(i)
                    continue
                fname = items[0].split(":")
                if len(fname) > 1:
                    if len(self.chunk_names) == 0 or fname[0] != self.chunk_names[-1]:
                        self.chunk_names.append(fname[0])
                        self.chunk_indices.append(len(self.fnames))
                self.fnames.append(items[0])
                if len(items) > 2:
                    speaker = int(items[2])
                else:
                    speaker = int(items[0].split("/")[-1].split("-")[0])
                if speaker not in self.speaker_dict:
                    self.speaker_dict[speaker] = speaker_count
                    speaker_count += 1
                self.speakers.append(self.speaker_dict[speaker])
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )
            pass

    def get_mfcc(self, wav, sample_rate=16000, normalize=True):
        try:
            import torchaudio
            import torchaudio.compliance.kaldi as ta_kaldi

            with torch.no_grad():
                x = torch.from_numpy(wav).float()
                x = x.view(1, -1)
                mfccs = ta_kaldi.mfcc(
                    waveform=x,
                    sample_frequency=sample_rate,
                    use_energy=False,
                )  # (time, freq)
                mfccs = mfccs.transpose(0, 1)  # (freq, time)
                deltas = torchaudio.functional.compute_deltas(mfccs)
                ddeltas = torchaudio.functional.compute_deltas(deltas)
                concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
                concat = concat.transpose(0, 1).contiguous()  # (time, freq)
                if normalize:
                    mean = concat.mean(dim=0)
                    std = concat.std(dim=0)
                    concat = (concat - mean) / std
                return concat
        except ImportError:
            return None

    def get_logmel(
        self,
        wav,
        sample_rate=16000,
        preemph=0.97,
        n_fft=2048,
        n_mels=80,
        hop_length=160,
        win_length=400,
        fmin=50,
        top_db=80,
        bits=8,
        offset=0.0,
        duration=None,
    ):
        wav = wav / np.abs(wav).max() * 0.999
        try:
            import librosa

            mel = librosa.feature.melspectrogram(
                preemphasis(wav, preemph),
                sr=sample_rate,
                n_fft=n_fft,
                n_mels=n_mels,
                hop_length=hop_length,
                win_length=win_length,
                fmin=fmin,
                power=1,
            )
            logmel = librosa.amplitude_to_db(mel, top_db=top_db)
            logmel = logmel / top_db + 1
            logmel = torch.from_numpy(logmel).transpose(0, 1)
            return logmel
        except ImportError:
            return None

    def get_fbank(self, wav, n_bins=80, sample_rate=16000, normalize=True):
        try:
            import torchaudio.compliance.kaldi as ta_kaldi

            x = torch.from_numpy(wav).float()
            x = x.view(1, -1)
            features = ta_kaldi.fbank(x, num_mel_bins=n_bins, sample_frequency=sample_rate)
            if normalize:
                mean = features.mean(dim=0)
                std = features.std(dim=0)
                features = (features - mean) / std
            return features
        except ImportError:
            return None

    def mulaw_encode(self, wav):
        wav = wav / np.abs(wav).max() * 0.999
        wav = mulaw_encode(wav, mu=2 ** 8)
        return wav

    def __getitem__(self, index):
        import soundfile as sf

        path_or_fp = os.path.join(self.root_dir, str(self.fnames[index]))
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)
        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        if self.input_feature == "wav":
            feats = torch.from_numpy(wav).float()
            feats = self.postprocess(feats, curr_sample_rate)
        elif self.input_feature == "fbank":
            feats = self.get_fbank(wav, n_bins=80, sample_rate=curr_sample_rate)
        elif self.input_feature == "mfcc":
            feats = self.get_mfcc(wav, sample_rate=curr_sample_rate)
        elif self.input_feature == "logmel":
            feats = self.get_logmel(wav, sample_rate=curr_sample_rate)
        elif self.input_feature == "mulaw":
            feats = self.mulaw_encode(wav)
            feats = torch.from_numpy(feats).long()
        else:
            raise ValueError("Unknown extra features {}".format(self.input_feature))

        if self.output_feature == self.input_feature:
            target = feats
        elif self.output_feature == "wav":
            target = torch.from_numpy(wav).float()
            feats = self.postprocess(feats, curr_sample_rate)
        elif self.output_feature == "fbank":
            target = self.get_fbank(wav, n_bins=80, sample_rate=curr_sample_rate)
        elif self.output_feature == "mfcc":
            target = self.get_mfcc(wav, sample_rate=curr_sample_rate)
        elif self.output_feature == "logmel":
            target = self.get_logmel(wav, sample_rate=curr_sample_rate)
        elif self.output_feature == "mulaw":
            target = self.mulaw_encode(wav)
            target = torch.from_numpy(target).long()
        else:
            raise ValueError("Unknown extra features {}".format(self.output_feature))

        return {
            "id": index,
            "input": feats,
            "target": target,
            "speaker": self.speakers[index],
        }

    def collater(self, samples):
        samples = [s for s in samples if s["input"] is not None]
        if len(samples) == 0:
            return {}

        inputs = [s["input"] for s in samples]
        targets = [s["target"] for s in samples]
        sizes = [len(s) for s in inputs]
        speakers = [s["speaker"] for s in samples]

        input_size = min(min(sizes), self.max_sample_size)
        if input_size % 2 != 0:
            input_size = input_size - 1

        """
        if self.input_feature == "wav" or self.input_feature == "mulaw":
            if self.output_feature in ["mfcc", "fbank"]:
                target_rate = 1.0 / 160
            if self.output_feature == "logmel":
                target_rate = 1.0 / 160
                start_offset = -1
                end_offset = 1
        elif self.input_feature == "mfcc" or self.input_feature == "fbank":
            if self.output_feature not in ["mfcc", "fbank", "logmel"]:
                target_rate = 160
        elif self.input_feature == "logmel":
            if self.output_feature not in ["mfcc", "fbank", "logmel"]:
                target_rate = 160
        """
        if self.input_feature == self.output_feature:
            target_rate = 1
            offset = 0
        elif self.input_feature == "logmel" and self.output_feature == "mulaw":
            target_rate = 160
            offset = 1
        else:
            raise ValueError(
                "Unsupported {} and {}".format(self.input_feature, self.output_feature)
            )

        if inputs[0].dim() == 2:
            collated_inputs = inputs[0].new_zeros(
                len(inputs), input_size + offset * 2, inputs[0].shape[-1]
            )
        else:
            collated_inputs = inputs[0].new_zeros(len(inputs), input_size + offset * 2)
        if targets[0].dim() == 2:
            collated_targets = targets[0].new_zeros(
                len(inputs), input_size * target_rate + offset, targets[0].shape[-1]
            )
        else:
            collated_targets = targets[0].new_zeros(
                len(inputs), input_size * target_rate + offset
            )

        for i, (input, size) in enumerate(zip(inputs, sizes)):
            size = len(input)
            start = np.random.randint(offset, size - input_size + 1)
            collated_inputs[i] = input[start - offset : start + input_size + offset]
            collated_targets[i] = targets[i][
                start * target_rate : (start + input_size) * target_rate + offset
            ]

        out = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "speakers": torch.LongTensor(speakers),
        }
        out["input"] = collated_inputs
        out["target"] = collated_targets
        return out
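
A sketch of the tab-separated manifest that FeatsAudioDataset consumes (hypothetical paths and counts): the first line is the root directory, each following line is "<relative path>\t<num samples>" with an optional third "<speaker id>" column; when the column is absent, the speaker is taken from the LibriSpeech-style "spk-chapter-utt" file-name prefix.

manifest = "\n".join([
    "/data/LibriSpeech",
    "train-clean-100/19/198/19-198-0001.flac\t186160\t19",
    # no speaker column: falls back to the "19-..." file-name prefix
    "train-clean-100/19/198/19-198-0002.flac\t147200",
])
with open("train.tsv", "w") as f:
    f.write(manifest + "\n")

# logmel -> mulaw is one of the two pairings collater() supports:
# ds = FeatsAudioDataset("train.tsv", sample_rate=16000,
#                        input_feature="logmel", output_feature="mulaw")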
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/__init__.py
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, Optional


class AudioFeatureTransform(ABC):
    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        pass


AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()


def register_audio_feature_transform(name):
    def register_audio_feature_transform_cls(cls):
        if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
            raise ValueError(f"Cannot register duplicate transform ({name})")
        if not issubclass(cls, AudioFeatureTransform):
            raise ValueError(
                f"Transform ({name}: {cls.__name__}) must extend AudioFeatureTransform"
            )
        if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
            raise ValueError(
                f"Cannot register audio feature transform with duplicate "
                f"class name ({cls.__name__})"
            )
        AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_audio_feature_transform_cls


def get_audio_feature_transform(name):
    return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]


transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
    path = os.path.join(transforms_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        name = file[: file.find(".py")] if file.endswith(".py") else file
        importlib.import_module("fairseq.data.audio.feature_transforms." + name)


class CompositeAudioFeatureTransform(AudioFeatureTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        _transforms = _config.get("transforms")
        if _transforms is None:
            return None
        transforms = [
            get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
            for _t in _transforms
        ]
        return CompositeAudioFeatureTransform(transforms)

    def __init__(self, transforms):
        self.transforms = [t for t in transforms if t is not None]

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def __repr__(self):
        format_string = (
            [self.__class__.__name__ + "("]
            + [f"    {t.__repr__()}" for t in self.transforms]
            + [")"]
        )
        return "\n".join(format_string)
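
A usage sketch for the registry (assuming the sibling modules below have already been auto-imported by the loop above; "stats.npz" is a hypothetical path that must exist, since GlobalCMVN loads it eagerly):

config = {
    "transforms": ["global_cmvn", "specaugment"],
    "global_cmvn": {"stats_npz_path": "stats.npz"},
    "specaugment": {"freq_mask_N": 2, "freq_mask_F": 27},
}
transform = CompositeAudioFeatureTransform.from_config_dict(config)
# transform(spectrogram) now applies global CMVN, then SpecAugment;
# printing it shows one line per registered sub-transform.
print(transform)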
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/global_cmvn.py
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
    """Global CMVN (cepstral mean and variance normalization). The global mean
    and variance need to be pre-computed and stored in NumPy format (.npz)."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return GlobalCMVN(_config.get("stats_npz_path"))

    def __init__(self, stats_npz_path):
        self.stats_npz_path = stats_npz_path
        stats = np.load(stats_npz_path)
        self.mean, self.std = stats["mean"], stats["std"]

    def __repr__(self):
        return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'

    def __call__(self, x):
        x = np.subtract(x, self.mean)
        x = np.divide(x, self.std)
        return x
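
A minimal sketch (not part of the original file) of producing the expected .npz statistics and applying the transform:

import numpy as np

feats = np.random.randn(1000, 80).astype(np.float32)  # (frames, mel bins)
np.savez("stats.npz", mean=feats.mean(axis=0), std=feats.std(axis=0))

cmvn = GlobalCMVN("stats.npz")
normalized = cmvn(feats)  # ~zero mean, ~unit variance per dimension
assert abs(float(normalized.mean())) < 1e-3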
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/data/audio/feature_transforms/specaugment.py
import math
import numbers
from typing import Optional

import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return SpecAugmentTransform(
            _config.get("time_warp_W", 0),
            _config.get("freq_mask_N", 0),
            _config.get("freq_mask_F", 0),
            _config.get("time_mask_N", 0),
            _config.get("time_mask_T", 0),
            _config.get("time_mask_p", 0.0),
            _config.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        # Sanity checks
        assert mask_value is None or isinstance(
            mask_value, numbers.Number
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.time_warp_w = time_warp_w
        self.freq_mask_n = freq_mask_n
        self.freq_mask_f = freq_mask_f
        self.time_mask_n = time_mask_n
        self.time_mask_t = time_mask_t
        self.time_mask_p = time_mask_p
        self.mask_value = mask_value

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"time_warp_w={self.time_warp_w}",
                    f"freq_mask_n={self.freq_mask_n}",
                    f"freq_mask_f={self.freq_mask_f}",
                    f"time_mask_n={self.time_mask_n}",
                    f"time_mask_t={self.time_mask_t}",
                    f"time_mask_p={self.time_mask_p}",
                ]
            )
            + ")"
        )

    def __call__(self, spectrogram):
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        distorted = spectrogram.copy()  # make a copy of input spectrogram.
        num_frames = spectrogram.shape[0]  # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'miu' in the paper.
        mask_value = self.mask_value

        if mask_value is None:  # if no value was specified, use local mean.
            mask_value = spectrogram.mean()

        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        if self.time_warp_w > 0:
            if 2 * self.time_warp_w < num_frames:
                import cv2

                w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
                w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
                upper, lower = distorted[:w0, :], distorted[w0:, :]
                upper = cv2.resize(
                    upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
                )
                lower = cv2.resize(
                    lower,
                    dsize=(num_freqs, num_frames - w0 - w),
                    interpolation=cv2.INTER_LINEAR,
                )
                distorted = np.concatenate((upper, lower), axis=0)

        for _i in range(self.freq_mask_n):
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                distorted[:, f0 : f0 + f] = mask_value

        max_time_mask_t = min(
            self.time_mask_t, math.floor(num_frames * self.time_mask_p)
        )
        if max_time_mask_t < 1:
            return distorted

        for _i in range(self.time_mask_n):
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                distorted[t0 : t0 + t, :] = mask_value

        return distorted
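
A usage sketch with mask sizes loosely following the LibriSpeech "LD" policy from the SpecAugment paper (time warping disabled, so cv2 is not needed):

import numpy as np

spec = np.random.randn(200, 80).astype(np.float32)  # (frames, freq bins)
transform = SpecAugmentTransform.from_config_dict({
    "freq_mask_N": 2, "freq_mask_F": 27,
    "time_mask_N": 2, "time_mask_T": 100, "time_mask_p": 1.0,
})
augmented = transform(spec)   # masked copy; the input is left untouched
assert augmented.shape == spec.shape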