Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
c0f05c10
Commit
c0f05c10
authored
Nov 29, 2022
by
hepj
Browse files
更新transformer代码
parent
c056df78
Changes
321
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2931 additions
and
0 deletions
+2931
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction.py
...new-Transformer/fairseq/criterions/sentence_prediction.py
+141
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction_adapters.py
...former/fairseq/criterions/sentence_prediction_adapters.py
+63
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_ranking.py
...LP/new-Transformer/fairseq/criterions/sentence_ranking.py
+120
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_to_speech_criterion.py
...nsformer/fairseq/criterions/speech_to_speech_criterion.py
+310
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_ulm_criterion.py
...ew-Transformer/fairseq/criterions/speech_ulm_criterion.py
+126
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/tacotron2_loss.py
.../NLP/new-Transformer/fairseq/criterions/tacotron2_loss.py
+226
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/wav2vec_criterion.py
...P/new-Transformer/fairseq/criterions/wav2vec_criterion.py
+230
-0
PyTorch/NLP/new-Transformer/fairseq/data/__init__.py
PyTorch/NLP/new-Transformer/fairseq/data/__init__.py
+130
-0
PyTorch/NLP/new-Transformer/fairseq/data/add_target_dataset.py
...ch/NLP/new-Transformer/fairseq/data/add_target_dataset.py
+83
-0
PyTorch/NLP/new-Transformer/fairseq/data/append_token_dataset.py
.../NLP/new-Transformer/fairseq/data/append_token_dataset.py
+41
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/__init__.py
PyTorch/NLP/new-Transformer/fairseq/data/audio/__init__.py
+0
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/audio_utils.py
...rch/NLP/new-Transformer/fairseq/data/audio/audio_utils.py
+293
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/data_cfg.py
PyTorch/NLP/new-Transformer/fairseq/data/audio/data_cfg.py
+299
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/__init__.py
...sformer/fairseq/data/audio/feature_transforms/__init__.py
+82
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/delta_deltas.py
...mer/fairseq/data/audio/feature_transforms/delta_deltas.py
+37
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/global_cmvn.py
...rmer/fairseq/data/audio/feature_transforms/global_cmvn.py
+29
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/specaugment.py
...rmer/fairseq/data/audio/feature_transforms/specaugment.py
+131
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/utterance_cmvn.py
...r/fairseq/data/audio/feature_transforms/utterance_cmvn.py
+41
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/frm_text_to_speech_dataset.py
...nsformer/fairseq/data/audio/frm_text_to_speech_dataset.py
+205
-0
PyTorch/NLP/new-Transformer/fairseq/data/audio/hubert_dataset.py
.../NLP/new-Transformer/fairseq/data/audio/hubert_dataset.py
+344
-0
No files found.
Too many changes to show.
To preserve performance only
321 of 321+
files are displayed.
Plain diff
Email patch
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
,
field
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
metrics
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
@
dataclass
class
SentencePredictionConfig
(
FairseqDataclass
):
classification_head_name
:
str
=
field
(
default
=
"sentence_classification_head"
,
metadata
=
{
"help"
:
"name of the classification head to use"
},
)
regression_target
:
bool
=
field
(
default
=
False
,
)
@
register_criterion
(
"sentence_prediction"
,
dataclass
=
SentencePredictionConfig
)
class
SentencePredictionCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
cfg
:
SentencePredictionConfig
,
task
):
super
().
__init__
(
task
)
self
.
classification_head_name
=
cfg
.
classification_head_name
self
.
regression_target
=
cfg
.
regression_target
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert
(
hasattr
(
model
,
"classification_heads"
)
and
self
.
classification_head_name
in
model
.
classification_heads
),
"model must provide sentence classification head for --criterion=sentence_prediction"
logits
,
_
=
model
(
**
sample
[
"net_input"
],
features_only
=
True
,
classification_head_name
=
self
.
classification_head_name
,
)
targets
=
model
.
get_targets
(
sample
,
[
logits
]).
view
(
-
1
)
sample_size
=
targets
.
numel
()
if
not
self
.
regression_target
:
lprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
task_loss
=
F
.
nll_loss
(
lprobs
,
targets
,
reduction
=
"sum"
)
else
:
logits
=
logits
.
view
(
-
1
).
float
()
targets
=
targets
.
float
()
task_loss
=
F
.
mse_loss
(
logits
,
targets
,
reduction
=
"sum"
)
logging_output
=
{}
loss
=
task_loss
# mha & ffn regularization update
if
(
hasattr
(
model
.
args
,
"mha_reg_scale_factor"
)
and
model
.
args
.
mha_reg_scale_factor
!=
0.0
):
mha_reg_loss
=
model
.
_get_adaptive_head_loss
()
loss
+=
mha_reg_loss
logging_output
.
update
({
"mha_reg_loss"
:
mha_reg_loss
})
if
(
hasattr
(
model
.
args
,
"ffn_reg_scale_factor"
)
and
model
.
args
.
ffn_reg_scale_factor
!=
0.0
):
ffn_reg_loss
=
model
.
_get_adaptive_ffn_loss
()
loss
+=
ffn_reg_loss
logging_output
.
update
({
"ffn_reg_loss"
:
ffn_reg_loss
})
logging_output
.
update
(
{
"loss"
:
loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample_size
,
"sample_size"
:
sample_size
,
}
)
if
not
self
.
regression_target
:
preds
=
logits
.
argmax
(
dim
=
1
)
logging_output
[
"ncorrect"
]
=
(
preds
==
targets
).
sum
()
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
ntokens
=
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
)
nsentences
=
sum
(
log
.
get
(
"nsentences"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
mha_reg_loss_sum
=
sum
(
log
.
get
(
"mha_reg_loss"
,
0
)
for
log
in
logging_outputs
)
ffn_reg_loss_sum
=
sum
(
log
.
get
(
"ffn_reg_loss"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
if
mha_reg_loss_sum
:
metrics
.
log_scalar
(
"mha_reg_loss"
,
mha_reg_loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
,
)
if
ffn_reg_loss_sum
:
metrics
.
log_scalar
(
"ffn_reg_loss"
,
ffn_reg_loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
,
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
if
len
(
logging_outputs
)
>
0
and
"ncorrect"
in
logging_outputs
[
0
]:
ncorrect
=
sum
(
log
.
get
(
"ncorrect"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"accuracy"
,
100.0
*
ncorrect
/
nsentences
,
nsentences
,
round
=
1
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction_adapters.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
import
torch.nn.functional
as
F
from
fairseq.criterions
import
register_criterion
from
fairseq.criterions.sentence_prediction
import
(
SentencePredictionCriterion
,
SentencePredictionConfig
,
)
@
register_criterion
(
"sentence_prediction_adapters"
,
dataclass
=
SentencePredictionConfig
)
class
SentencePredictionCriterionAdapters
(
SentencePredictionCriterion
):
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert
(
hasattr
(
model
,
"classification_heads"
)
and
self
.
classification_head_name
in
model
.
classification_heads
),
"model must provide sentence classification head for --criterion=sentence_prediction"
if
not
hasattr
(
sample
,
"lang_id"
):
# If no language ID is given, we fall back to English
lang_id
=
[
"en_XX"
]
*
sample
[
"nsentences"
]
else
:
lang_id
=
sample
[
"lang_id"
]
logits
,
_
=
model
(
**
sample
[
"net_input"
],
features_only
=
True
,
classification_head_name
=
self
.
classification_head_name
,
lang_id
=
lang_id
,
)
targets
=
model
.
get_targets
(
sample
,
[
logits
]).
view
(
-
1
)
sample_size
=
targets
.
numel
()
if
not
self
.
regression_target
:
lprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
loss
=
F
.
nll_loss
(
lprobs
,
targets
,
reduction
=
"sum"
)
else
:
logits
=
logits
.
view
(
-
1
).
float
()
targets
=
targets
.
float
()
loss
=
F
.
mse_loss
(
logits
,
targets
,
reduction
=
"sum"
)
logging_output
=
{
"loss"
:
loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample_size
,
"sample_size"
:
sample_size
,
}
if
not
self
.
regression_target
:
preds
=
logits
.
argmax
(
dim
=
1
)
logging_output
[
"ncorrect"
]
=
(
preds
==
targets
).
sum
()
return
loss
,
sample_size
,
logging_output
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_ranking.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
@
register_criterion
(
"sentence_ranking"
)
class
SentenceRankingCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
ranking_head_name
,
save_predictions
,
num_classes
):
super
().
__init__
(
task
)
self
.
ranking_head_name
=
ranking_head_name
if
save_predictions
is
not
None
:
self
.
prediction_h
=
open
(
save_predictions
,
"w"
)
else
:
self
.
prediction_h
=
None
self
.
num_classes
=
num_classes
def
__del__
(
self
):
if
self
.
prediction_h
is
not
None
:
self
.
prediction_h
.
close
()
@
staticmethod
def
add_args
(
parser
):
# fmt: off
parser
.
add_argument
(
'--save-predictions'
,
metavar
=
'FILE'
,
help
=
'file to save predictions to'
)
parser
.
add_argument
(
'--ranking-head-name'
,
default
=
'sentence_classification_head'
,
help
=
'name of the ranking head to use'
)
# fmt: on
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute ranking loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert
(
hasattr
(
model
,
"classification_heads"
)
and
self
.
ranking_head_name
in
model
.
classification_heads
),
"model must provide sentence ranking head for --criterion=sentence_ranking"
scores
=
[]
for
idx
in
range
(
self
.
num_classes
):
score
,
_
=
model
(
**
sample
[
"net_input{idx}"
.
format
(
idx
=
idx
+
1
)],
classification_head_name
=
self
.
ranking_head_name
,
)
scores
.
append
(
score
)
logits
=
torch
.
cat
(
scores
,
dim
=
1
)
sample_size
=
logits
.
size
(
0
)
if
"target"
in
sample
:
targets
=
model
.
get_targets
(
sample
,
[
logits
]).
view
(
-
1
)
lprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
loss
=
F
.
nll_loss
(
lprobs
,
targets
,
reduction
=
"sum"
)
else
:
targets
=
None
loss
=
torch
.
tensor
(
0.0
,
requires_grad
=
True
)
if
self
.
prediction_h
is
not
None
:
preds
=
logits
.
argmax
(
dim
=
1
)
for
i
,
(
id
,
pred
)
in
enumerate
(
zip
(
sample
[
"id"
].
tolist
(),
preds
.
tolist
())):
if
targets
is
not
None
:
label
=
targets
[
i
].
item
()
print
(
"{}
\t
{}
\t
{}"
.
format
(
id
,
pred
,
label
),
file
=
self
.
prediction_h
)
else
:
print
(
"{}
\t
{}"
.
format
(
id
,
pred
),
file
=
self
.
prediction_h
)
logging_output
=
{
"loss"
:
loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample_size
,
"sample_size"
:
sample_size
,
}
if
targets
is
not
None
:
logging_output
[
"ncorrect"
]
=
(
logits
.
argmax
(
dim
=
1
)
==
targets
).
sum
()
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
ntokens
=
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
)
nsentences
=
sum
(
log
.
get
(
"nsentences"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
if
len
(
logging_outputs
)
>
0
and
"ncorrect"
in
logging_outputs
[
0
]:
ncorrect
=
sum
(
log
.
get
(
"ncorrect"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"accuracy"
,
100.0
*
ncorrect
/
nsentences
,
nsentences
,
round
=
1
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_to_speech_criterion.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
torch
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
register_criterion
from
fairseq.criterions.ctc
import
CtcCriterion
from
fairseq.criterions.label_smoothed_cross_entropy
import
(
LabelSmoothedCrossEntropyCriterion
,
LabelSmoothedCrossEntropyCriterionConfig
,
)
from
fairseq.criterions.tacotron2_loss
import
(
Tacotron2Criterion
,
Tacotron2CriterionConfig
,
)
class
MultitaskCriterion
:
def
__init__
(
self
,
multitask_tasks
):
self
.
multitask_criterion
=
{}
self
.
multitask_loss_weight
=
{}
for
task_name
,
task_obj
in
multitask_tasks
.
items
():
if
task_obj
.
args
.
decoder_type
==
"ctc"
:
self
.
multitask_criterion
[
task_name
]
=
CtcCriterion
(
task_obj
.
args
.
criterion_cfg
,
task_obj
)
else
:
self
.
multitask_criterion
[
task_name
]
=
LabelSmoothedCrossEntropyCriterion
(
task_obj
,
task_obj
.
args
.
criterion_cfg
.
sentence_avg
,
label_smoothing
=
task_obj
.
args
.
criterion_cfg
.
label_smoothing
,
)
def
set_multitask_loss_weight
(
self
,
task_name
,
weight
=
0.0
):
self
.
multitask_loss_weight
[
task_name
]
=
weight
def
get_multitask_loss
(
self
,
model
,
sample
,
model_out
):
logging_output
=
{}
loss
=
0.0
for
task_name
,
task_criterion
in
self
.
multitask_criterion
.
items
():
layer_id
=
task_criterion
.
task
.
args
.
input_layer
if
isinstance
(
task_criterion
,
CtcCriterion
):
if
task_criterion
.
task
.
args
.
input_from
==
"encoder"
:
non_padding_mask
=
~
model_out
[
"encoder_padding_mask"
][
0
]
input_lengths
=
non_padding_mask
.
long
().
sum
(
-
1
)
task_sample
=
{
"net_input"
:
{
"src_tokens"
:
model_out
[
"encoder_states"
][
layer_id
],
# check batch idx
"src_lengths"
:
input_lengths
,
},
"id"
:
sample
[
"id"
],
}
else
:
task_sample
=
{
"net_input"
:
{
"src_tokens"
:
model_out
[
"inner_states"
][
layer_id
],
"src_lengths"
:
sample
[
"target_lengths"
],
},
"id"
:
sample
[
"id"
],
}
else
:
task_sample
=
{
"net_input"
:
{
"src_tokens"
:
sample
[
"multitask"
][
task_name
][
"net_input"
][
"prev_output_tokens"
],
"encoder_out"
:
{
"encoder_out"
:
[
model_out
[
"encoder_states"
][
layer_id
]],
"encoder_padding_mask"
:
model_out
[
"encoder_padding_mask"
],
},
}
}
for
key
in
[
"target"
,
"target_lengths"
,
"ntokens"
]:
task_sample
[
key
]
=
sample
[
"multitask"
][
task_name
][
key
]
task_loss
,
task_sample_size
,
task_logging_output
=
task_criterion
(
model
.
multitask_decoders
[
task_name
],
task_sample
)
loss
=
loss
+
self
.
multitask_loss_weight
[
task_name
]
*
task_loss
task_logging_output
[
"loss_weight"
]
=
self
.
multitask_loss_weight
[
task_name
]
logging_output
[
task_name
]
=
task_logging_output
return
loss
,
logging_output
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
)
->
None
:
for
task_name
in
logging_outputs
[
0
][
"multitask"
].
keys
():
# different criterion may return different logging
# currently only reduce on loss, the most common one
# ideally the way that losses are reduced should also depend on the task type
loss_sum
=
sum
(
log
[
"multitask"
][
task_name
].
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
[
"multitask"
][
task_name
].
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
f
"multitask_
{
task_name
}
_loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
,
)
loss_weight
=
logging_outputs
[
0
][
"multitask"
][
task_name
].
get
(
"loss_weight"
,
0
)
metrics
.
log_scalar
(
f
"multitask_
{
task_name
}
_loss_weight"
,
loss_weight
,
weight
=
0
,
priority
=
250
,
)
@
register_criterion
(
"speech_to_unit"
,
dataclass
=
LabelSmoothedCrossEntropyCriterionConfig
)
class
SpeechToUnitMultitaskTaskCriterion
(
LabelSmoothedCrossEntropyCriterion
,
MultitaskCriterion
):
def
__init__
(
self
,
task
,
sentence_avg
,
label_smoothing
,
ignore_prefix_size
=
0
,
report_accuracy
=
False
,
):
super
().
__init__
(
task
,
sentence_avg
,
label_smoothing
,
ignore_prefix_size
,
report_accuracy
)
MultitaskCriterion
.
__init__
(
self
,
task
.
multitask_tasks
)
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
net_output
,
extra
=
model
(
src_tokens
=
sample
[
"net_input"
][
"src_tokens"
],
src_lengths
=
sample
[
"net_input"
][
"src_lengths"
],
prev_output_tokens
=
sample
[
"net_input"
][
"prev_output_tokens"
],
tgt_speaker
=
sample
[
"net_input"
][
"tgt_speaker"
],
return_all_hiddens
=
True
,
)
loss
,
nll_loss
=
self
.
compute_loss
(
model
,
[
net_output
],
sample
,
reduce
=
reduce
)
sample_size
=
(
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
)
logging_output
=
{
"loss"
:
loss
.
data
,
"nll_loss"
:
nll_loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"target"
].
size
(
0
),
"sample_size"
:
sample_size
,
}
if
self
.
report_accuracy
:
n_correct
,
total
=
self
.
compute_accuracy
(
model
,
[
net_output
],
sample
)
logging_output
[
"n_correct"
]
=
utils
.
item
(
n_correct
.
data
)
logging_output
[
"total"
]
=
utils
.
item
(
total
.
data
)
if
len
(
self
.
multitask_criterion
)
==
0
:
return
loss
,
sample_size
,
logging_output
# multitask
multitask_loss
,
multitask_log
=
self
.
get_multitask_loss
(
model
,
sample
,
extra
)
loss
+=
multitask_loss
logging_output
[
"multitask"
]
=
multitask_log
return
loss
,
sample_size
,
logging_output
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
)
->
None
:
super
().
reduce_metrics
(
logging_outputs
)
# inference metrics
if
"targ_frames"
in
logging_outputs
[
0
]:
n
=
sum
(
log
.
get
(
"norm_frames"
,
0
)
for
log
in
logging_outputs
)
for
key
,
new_key
in
[
(
"mcd_loss"
,
"mcd_loss"
),
(
"pred_frames"
,
"pred_ratio"
),
(
"nins"
,
"ins_rate"
),
(
"ndel"
,
"del_rate"
),
]:
val
=
sum
(
log
.
get
(
key
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
new_key
,
val
/
n
,
n
,
round
=
3
)
if
"multitask"
not
in
logging_outputs
[
0
]:
return
MultitaskCriterion
.
reduce_metrics
(
logging_outputs
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
False
@
register_criterion
(
"speech_to_spectrogram"
,
dataclass
=
Tacotron2CriterionConfig
)
class
SpeechToSpectrogramMultitaskTaskCriterion
(
Tacotron2Criterion
,
MultitaskCriterion
):
def
__init__
(
self
,
task
,
sentence_avg
,
use_guided_attention_loss
,
guided_attention_loss_sigma
,
bce_pos_weight
,
ctc_weight
,
):
super
().
__init__
(
task
,
sentence_avg
,
use_guided_attention_loss
,
guided_attention_loss_sigma
,
bce_pos_weight
,
ctc_weight
,
)
MultitaskCriterion
.
__init__
(
self
,
task
.
multitask_tasks
)
def
forward
(
self
,
model
,
sample
,
reduction
=
"mean"
):
bsz
,
max_len
,
_
=
sample
[
"target"
].
size
()
feat_tgt
=
sample
[
"target"
]
feat_len
=
sample
[
"target_lengths"
].
view
(
bsz
,
1
).
expand
(
-
1
,
max_len
)
eos_tgt
=
torch
.
arange
(
max_len
).
to
(
sample
[
"target"
].
device
)
eos_tgt
=
eos_tgt
.
view
(
1
,
max_len
).
expand
(
bsz
,
-
1
)
eos_tgt
=
(
eos_tgt
==
(
feat_len
-
1
)).
float
()
feat_out
,
eos_out
,
extra
=
model
(
src_tokens
=
sample
[
"net_input"
][
"src_tokens"
],
src_lengths
=
sample
[
"net_input"
][
"src_lengths"
],
prev_output_tokens
=
sample
[
"net_input"
][
"prev_output_tokens"
],
tgt_speaker
=
sample
[
"net_input"
][
"tgt_speaker"
],
target_lengths
=
sample
[
"target_lengths"
],
return_all_hiddens
=
True
,
)
l1_loss
,
mse_loss
,
eos_loss
=
self
.
compute_loss
(
extra
[
"feature_out"
],
feat_out
,
eos_out
,
feat_tgt
,
eos_tgt
,
sample
[
"target_lengths"
],
reduction
,
)
attn_loss
=
torch
.
tensor
(
0.0
).
type_as
(
l1_loss
)
if
self
.
guided_attn
is
not
None
:
attn_loss
=
self
.
guided_attn
(
extra
[
"attn"
],
sample
[
"net_input"
][
"src_lengths"
],
sample
[
"target_lengths"
],
reduction
,
)
loss
=
(
l1_loss
+
mse_loss
+
eos_loss
+
attn_loss
)
# do not include ctc loss as there's no text target
sample_size
=
sample
[
"nsentences"
]
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
logging_output
=
{
"loss"
:
utils
.
item
(
loss
.
data
),
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"nsentences"
],
"sample_size"
:
sample_size
,
"l1_loss"
:
utils
.
item
(
l1_loss
.
data
),
"mse_loss"
:
utils
.
item
(
mse_loss
.
data
),
"eos_loss"
:
utils
.
item
(
eos_loss
.
data
),
"attn_loss"
:
utils
.
item
(
attn_loss
.
data
),
}
if
len
(
self
.
multitask_criterion
)
==
0
:
return
loss
,
sample_size
,
logging_output
# multitask
multitask_loss
,
multitask_log
=
self
.
get_multitask_loss
(
model
,
sample
,
extra
)
loss
+=
multitask_loss
logging_output
[
"multitask"
]
=
multitask_log
return
loss
,
sample_size
,
logging_output
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
)
->
None
:
super
().
reduce_metrics
(
logging_outputs
)
# inference metrics
if
"targ_frames"
in
logging_outputs
[
0
]:
n
=
sum
(
log
.
get
(
"norm_frames"
,
0
)
for
log
in
logging_outputs
)
for
key
,
new_key
in
[
(
"mcd_loss"
,
"mcd_loss"
),
(
"pred_frames"
,
"pred_ratio"
),
(
"nins"
,
"ins_rate"
),
(
"ndel"
,
"del_rate"
),
]:
val
=
sum
(
log
.
get
(
key
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
new_key
,
val
/
n
,
n
,
round
=
3
)
if
"multitask"
not
in
logging_outputs
[
0
]:
return
MultitaskCriterion
.
reduce_metrics
(
logging_outputs
)
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_ulm_criterion.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
from
dataclasses
import
dataclass
,
field
import
torch.nn.functional
as
F
from
fairseq
import
metrics
from
fairseq.tasks
import
FairseqTask
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
omegaconf
import
II
@
dataclass
class
SpeechUnitLmCriterionConfig
(
FairseqDataclass
):
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
loss_weights
:
str
=
field
(
default
=
"1.;0.0;0.0"
,
metadata
=
{
"help"
:
"Weights of the losses that correspond to token, duration, and F0 streams"
},
)
discrete_duration
:
bool
=
II
(
"task.discrete_duration"
)
discrete_f0
:
bool
=
II
(
"task.discrete_f0"
)
def
mae_loss
(
pred
,
targ
,
mask
,
reduce
=
True
):
if
pred
.
ndim
==
3
:
pred
=
pred
.
squeeze
(
2
)
else
:
assert
pred
.
ndim
==
2
loss
=
(
pred
.
float
()
-
targ
.
float
()).
abs
()
*
(
~
mask
).
float
()
loss
=
loss
.
sum
()
if
reduce
else
loss
.
view
(
-
1
)
return
loss
def
nll_loss
(
pred
,
targ
,
mask
,
reduce
=
True
):
lprob
=
F
.
log_softmax
(
pred
,
dim
=-
1
)
loss
=
F
.
nll_loss
(
lprob
.
view
(
-
1
,
lprob
.
size
(
-
1
)),
targ
.
view
(
-
1
),
reduction
=
"none"
)
loss
=
loss
*
(
~
mask
).
float
().
view
(
-
1
)
loss
=
loss
.
sum
()
if
reduce
else
loss
.
view
(
-
1
)
return
loss
@
register_criterion
(
"speech_unit_lm_criterion"
,
dataclass
=
SpeechUnitLmCriterionConfig
)
class
SpeechUnitLmCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
cfg
:
SpeechUnitLmCriterionConfig
,
task
:
FairseqTask
):
super
().
__init__
(
task
)
self
.
sentence_avg
=
cfg
.
sentence_avg
self
.
weights
=
torch
.
tensor
([
float
(
w
)
for
w
in
cfg
.
loss_weights
.
split
(
";"
)])
assert
self
.
weights
.
size
(
0
)
==
3
assert
(
self
.
weights
>=
0.0
).
all
()
self
.
dur_loss_fn
=
nll_loss
if
cfg
.
discrete_duration
else
mae_loss
self
.
f0_loss_fn
=
nll_loss
if
cfg
.
discrete_f0
else
mae_loss
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output
=
model
(
**
sample
[
"net_input"
])
token_loss
=
nll_loss
(
net_output
[
"token"
],
sample
[
"target"
],
sample
[
"mask"
],
reduce
)
dur_loss
=
self
.
dur_loss_fn
(
net_output
[
"duration"
],
sample
[
"dur_target"
],
sample
[
"dur_mask"
],
reduce
,
)
f0_loss
=
self
.
f0_loss_fn
(
net_output
[
"f0"
],
sample
[
"f0_target"
],
sample
[
"f0_mask"
],
reduce
,
)
loss
=
self
.
weights
.
to
(
token_loss
.
device
)
*
torch
.
stack
(
[
token_loss
,
dur_loss
,
f0_loss
],
dim
=-
1
)
loss
=
loss
.
sum
()
if
reduce
else
loss
.
sum
(
-
1
)
sample_size
=
(
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
)
logging_output
=
{
"loss"
:
loss
.
detach
().
sum
().
item
(),
"token_loss"
:
token_loss
.
detach
().
sum
().
item
(),
"dur_loss"
:
dur_loss
.
detach
().
sum
().
item
(),
"f0_loss"
:
f0_loss
.
detach
().
sum
().
item
(),
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"target"
].
size
(
0
),
"sample_size"
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
token_loss_sum
=
sum
(
log
.
get
(
"token_loss"
,
0
)
for
log
in
logging_outputs
)
dur_loss_sum
=
sum
(
log
.
get
(
"dur_loss"
,
0
)
for
log
in
logging_outputs
)
f0_loss_sum
=
sum
(
log
.
get
(
"f0_loss"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
,
sample_size
,
round
=
3
)
metrics
.
log_scalar
(
"token_loss"
,
token_loss_sum
/
sample_size
,
sample_size
,
round
=
3
)
metrics
.
log_scalar
(
"dur_loss"
,
dur_loss_sum
/
sample_size
,
sample_size
,
round
=
3
)
metrics
.
log_scalar
(
"f0_loss"
,
f0_loss_sum
/
sample_size
,
sample_size
,
round
=
3
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/tacotron2_loss.py
0 → 100644
View file @
c0f05c10
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
logging
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
typing
import
Any
,
Dict
,
List
import
torch
import
torch.nn.functional
as
F
from
omegaconf
import
II
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.data.data_utils
import
lengths_to_mask
from
fairseq.dataclass
import
FairseqDataclass
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Tacotron2CriterionConfig
(
FairseqDataclass
):
bce_pos_weight
:
float
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"weight of positive examples for BCE loss"
},
)
use_guided_attention_loss
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"use guided attention loss"
},
)
guided_attention_loss_sigma
:
float
=
field
(
default
=
0.4
,
metadata
=
{
"help"
:
"weight of positive examples for BCE loss"
},
)
ctc_weight
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"weight for CTC loss"
})
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
class
GuidedAttentionLoss
(
torch
.
nn
.
Module
):
"""
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
"""
def
__init__
(
self
,
sigma
):
super
().
__init__
()
self
.
sigma
=
sigma
@
staticmethod
@
lru_cache
(
maxsize
=
8
)
def
_get_weight
(
s_len
,
t_len
,
sigma
):
grid_x
,
grid_y
=
torch
.
meshgrid
(
torch
.
arange
(
t_len
),
torch
.
arange
(
s_len
))
grid_x
=
grid_x
.
to
(
s_len
.
device
)
grid_y
=
grid_y
.
to
(
s_len
.
device
)
w
=
(
grid_y
.
float
()
/
s_len
-
grid_x
.
float
()
/
t_len
)
**
2
return
1.0
-
torch
.
exp
(
-
w
/
(
2
*
(
sigma
**
2
)))
def
_get_weights
(
self
,
src_lens
,
tgt_lens
):
bsz
,
max_s_len
,
max_t_len
=
len
(
src_lens
),
max
(
src_lens
),
max
(
tgt_lens
)
weights
=
torch
.
zeros
((
bsz
,
max_t_len
,
max_s_len
))
for
i
,
(
s_len
,
t_len
)
in
enumerate
(
zip
(
src_lens
,
tgt_lens
)):
weights
[
i
,
:
t_len
,
:
s_len
]
=
self
.
_get_weight
(
s_len
,
t_len
,
self
.
sigma
)
return
weights
@
staticmethod
def
_get_masks
(
src_lens
,
tgt_lens
):
in_masks
=
lengths_to_mask
(
src_lens
)
out_masks
=
lengths_to_mask
(
tgt_lens
)
return
out_masks
.
unsqueeze
(
2
)
&
in_masks
.
unsqueeze
(
1
)
def
forward
(
self
,
attn
,
src_lens
,
tgt_lens
,
reduction
=
"mean"
):
weights
=
self
.
_get_weights
(
src_lens
,
tgt_lens
).
to
(
attn
.
device
)
masks
=
self
.
_get_masks
(
src_lens
,
tgt_lens
).
to
(
attn
.
device
)
loss
=
(
weights
*
attn
.
transpose
(
1
,
2
)).
masked_select
(
masks
)
loss
=
torch
.
sum
(
loss
)
if
reduction
==
"sum"
else
torch
.
mean
(
loss
)
return
loss
@
register_criterion
(
"tacotron2"
,
dataclass
=
Tacotron2CriterionConfig
)
class
Tacotron2Criterion
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
sentence_avg
,
use_guided_attention_loss
,
guided_attention_loss_sigma
,
bce_pos_weight
,
ctc_weight
,
):
super
().
__init__
(
task
)
self
.
sentence_avg
=
sentence_avg
self
.
bce_pos_weight
=
bce_pos_weight
self
.
guided_attn
=
None
if
use_guided_attention_loss
:
self
.
guided_attn
=
GuidedAttentionLoss
(
guided_attention_loss_sigma
)
self
.
ctc_weight
=
ctc_weight
def
forward
(
self
,
model
,
sample
,
reduction
=
"mean"
):
bsz
,
max_len
,
_
=
sample
[
"target"
].
size
()
feat_tgt
=
sample
[
"target"
]
feat_len
=
sample
[
"target_lengths"
].
view
(
bsz
,
1
).
expand
(
-
1
,
max_len
)
eos_tgt
=
torch
.
arange
(
max_len
).
to
(
sample
[
"target"
].
device
)
eos_tgt
=
eos_tgt
.
view
(
1
,
max_len
).
expand
(
bsz
,
-
1
)
eos_tgt
=
(
eos_tgt
==
(
feat_len
-
1
)).
float
()
src_tokens
=
sample
[
"net_input"
][
"src_tokens"
]
src_lens
=
sample
[
"net_input"
][
"src_lengths"
]
tgt_lens
=
sample
[
"target_lengths"
]
feat_out
,
eos_out
,
extra
=
model
(
src_tokens
=
src_tokens
,
src_lengths
=
src_lens
,
prev_output_tokens
=
sample
[
"net_input"
][
"prev_output_tokens"
],
incremental_state
=
None
,
target_lengths
=
tgt_lens
,
speaker
=
sample
[
"speaker"
],
)
l1_loss
,
mse_loss
,
eos_loss
=
self
.
compute_loss
(
extra
[
"feature_out"
],
feat_out
,
eos_out
,
feat_tgt
,
eos_tgt
,
tgt_lens
,
reduction
,
)
attn_loss
=
torch
.
tensor
(
0.0
).
type_as
(
l1_loss
)
if
self
.
guided_attn
is
not
None
:
attn_loss
=
self
.
guided_attn
(
extra
[
"attn"
],
src_lens
,
tgt_lens
,
reduction
)
ctc_loss
=
torch
.
tensor
(
0.0
).
type_as
(
l1_loss
)
if
self
.
ctc_weight
>
0.0
:
net_output
=
(
feat_out
,
eos_out
,
extra
)
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
)
lprobs
=
lprobs
.
transpose
(
0
,
1
)
# T x B x C
src_mask
=
lengths_to_mask
(
src_lens
)
src_tokens_flat
=
src_tokens
.
masked_select
(
src_mask
)
ctc_loss
=
(
F
.
ctc_loss
(
lprobs
,
src_tokens_flat
,
tgt_lens
,
src_lens
,
reduction
=
reduction
,
zero_infinity
=
True
,
)
*
self
.
ctc_weight
)
loss
=
l1_loss
+
mse_loss
+
eos_loss
+
attn_loss
+
ctc_loss
sample_size
=
sample
[
"nsentences"
]
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
logging_output
=
{
"loss"
:
utils
.
item
(
loss
.
data
),
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"nsentences"
],
"sample_size"
:
sample_size
,
"l1_loss"
:
utils
.
item
(
l1_loss
.
data
),
"mse_loss"
:
utils
.
item
(
mse_loss
.
data
),
"eos_loss"
:
utils
.
item
(
eos_loss
.
data
),
"attn_loss"
:
utils
.
item
(
attn_loss
.
data
),
"ctc_loss"
:
utils
.
item
(
ctc_loss
.
data
),
}
return
loss
,
sample_size
,
logging_output
def
compute_loss
(
self
,
feat_out
,
feat_out_post
,
eos_out
,
feat_tgt
,
eos_tgt
,
tgt_lens
,
reduction
=
"mean"
,
):
mask
=
lengths_to_mask
(
tgt_lens
)
_eos_out
=
eos_out
[
mask
].
squeeze
()
_eos_tgt
=
eos_tgt
[
mask
]
_feat_tgt
=
feat_tgt
[
mask
]
_feat_out
=
feat_out
[
mask
]
_feat_out_post
=
feat_out_post
[
mask
]
l1_loss
=
F
.
l1_loss
(
_feat_out
,
_feat_tgt
,
reduction
=
reduction
)
+
F
.
l1_loss
(
_feat_out_post
,
_feat_tgt
,
reduction
=
reduction
)
mse_loss
=
F
.
mse_loss
(
_feat_out
,
_feat_tgt
,
reduction
=
reduction
)
+
F
.
mse_loss
(
_feat_out_post
,
_feat_tgt
,
reduction
=
reduction
)
eos_loss
=
F
.
binary_cross_entropy_with_logits
(
_eos_out
,
_eos_tgt
,
pos_weight
=
torch
.
tensor
(
self
.
bce_pos_weight
),
reduction
=
reduction
,
)
return
l1_loss
,
mse_loss
,
eos_loss
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
:
List
[
Dict
[
str
,
Any
]])
->
None
:
ns
=
[
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
]
ntot
=
sum
(
ns
)
ws
=
[
n
/
(
ntot
+
1e-8
)
for
n
in
ns
]
for
key
in
[
"loss"
,
"l1_loss"
,
"mse_loss"
,
"eos_loss"
,
"attn_loss"
,
"ctc_loss"
]:
vals
=
[
log
.
get
(
key
,
0
)
for
log
in
logging_outputs
]
val
=
sum
(
val
*
w
for
val
,
w
in
zip
(
vals
,
ws
))
metrics
.
log_scalar
(
key
,
val
,
ntot
,
round
=
3
)
metrics
.
log_scalar
(
"sample_size"
,
ntot
,
len
(
logging_outputs
))
# inference metrics
if
"targ_frames"
not
in
logging_outputs
[
0
]:
return
n
=
sum
(
log
.
get
(
"targ_frames"
,
0
)
for
log
in
logging_outputs
)
for
key
,
new_key
in
[
(
"mcd_loss"
,
"mcd_loss"
),
(
"pred_frames"
,
"pred_ratio"
),
(
"nins"
,
"ins_rate"
),
(
"ndel"
,
"del_rate"
),
]:
val
=
sum
(
log
.
get
(
key
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
new_key
,
val
/
n
,
n
,
round
=
3
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
return
False
PyTorch/NLP/new-Transformer/fairseq/criterions/wav2vec_criterion.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.logging.meters
import
safe_round
from
fairseq.utils
import
is_xla_tensor
@
dataclass
class
Wav2VecCriterionConfig
(
FairseqDataclass
):
infonce
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
},
)
loss_weights
:
Optional
[
List
[
float
]]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"weights for additional loss terms (not first one)"
},
)
log_keys
:
List
[
str
]
=
field
(
default_factory
=
lambda
:
[],
metadata
=
{
"help"
:
"output keys to log"
},
)
@
register_criterion
(
"wav2vec"
,
dataclass
=
Wav2VecCriterionConfig
)
class
Wav2vecCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
infonce
=
False
,
loss_weights
=
None
,
log_keys
=
None
):
super
().
__init__
(
task
)
self
.
infonce
=
infonce
self
.
loss_weights
=
loss_weights
self
.
log_keys
=
[]
if
log_keys
is
None
else
log_keys
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output
=
model
(
**
sample
[
"net_input"
])
logits
=
model
.
get_logits
(
net_output
).
float
()
target
=
model
.
get_targets
(
sample
,
net_output
)
self
.
xla
=
is_xla_tensor
(
logits
)
# XXX: handle weights on xla.
weights
=
None
if
hasattr
(
model
,
"get_target_weights"
)
and
not
self
.
infonce
:
weights
=
model
.
get_target_weights
(
target
,
net_output
)
if
torch
.
is_tensor
(
weights
):
weights
=
weights
.
float
()
losses
=
[]
reduction
=
"none"
if
((
not
reduce
)
or
self
.
xla
)
else
"sum"
if
self
.
infonce
:
loss
=
F
.
cross_entropy
(
logits
,
target
,
reduction
=
reduction
)
else
:
loss
=
F
.
binary_cross_entropy_with_logits
(
logits
,
target
.
float
(),
weights
,
reduction
=
reduction
)
if
self
.
xla
:
# tpu-comment: since dynamic shapes lead to recompilations on xla,
# we don't shrink tensors using mask_indices.
# Instead, we use mask indices to adjust loss.
mi
=
(
sample
[
"net_input"
][
"mask_indices"
]
.
transpose
(
0
,
1
)
# logits are transposed in `model.get_logits`
.
reshape
(
logits
.
size
(
0
))
)
loss
=
(
loss
*
mi
).
sum
()
if
reduce
else
(
loss
*
mi
)
if
"sample_size"
in
sample
:
sample_size
=
sample
[
"sample_size"
]
elif
"mask_indices"
in
sample
[
"net_input"
]:
sample_size
=
sample
[
"net_input"
][
"mask_indices"
].
sum
()
else
:
sample_size
=
target
.
numel
()
if
self
.
infonce
else
target
.
long
().
sum
().
item
()
losses
.
append
(
loss
.
detach
().
clone
())
if
self
.
loss_weights
is
not
None
:
assert
hasattr
(
model
,
"get_extra_losses"
)
extra_losses
=
model
.
get_extra_losses
(
net_output
)
if
torch
.
is_tensor
(
extra_losses
):
extra_losses
=
[
extra_losses
]
if
len
(
self
.
loss_weights
)
==
1
and
len
(
extra_losses
)
!=
1
:
self
.
loss_weights
=
[
self
.
loss_weights
[
0
]]
*
len
(
extra_losses
)
assert
len
(
extra_losses
)
==
len
(
self
.
loss_weights
),
f
"
{
len
(
extra_losses
)
}
,
{
len
(
self
.
loss_weights
)
}
"
for
p
,
coef
in
zip
(
extra_losses
,
self
.
loss_weights
):
if
coef
!=
0
and
p
is
not
None
:
p
=
coef
*
p
.
float
()
*
sample_size
loss
+=
p
losses
.
append
(
p
)
logging_output
=
{
"loss"
:
loss
.
item
()
if
(
reduce
and
not
self
.
xla
)
else
loss
.
detach
(),
"ntokens"
:
sample_size
,
"nsentences"
:
sample
[
"id"
].
numel
(),
"sample_size"
:
sample_size
,
}
for
lk
in
self
.
log_keys
:
# Only store "logits" and "target" for computing MAP and MAUC
# during validation
if
lk
==
"logits"
:
if
not
self
.
training
:
logging_output
[
"logits"
]
=
logits
.
cpu
().
numpy
()
elif
lk
==
"target"
:
if
not
self
.
training
:
# If the targets have been mixed with the predictions of
# teacher models, find the original targets
if
hasattr
(
model
,
"get_original_targets"
):
original_target
=
model
.
get_original_targets
(
sample
,
net_output
)
else
:
original_target
=
target
logging_output
[
"target"
]
=
original_target
.
cpu
().
numpy
()
elif
lk
in
net_output
:
value
=
net_output
[
lk
]
if
not
is_xla_tensor
(
value
):
value
=
float
(
value
)
logging_output
[
lk
]
=
value
if
len
(
losses
)
>
1
:
for
i
,
l
in
enumerate
(
losses
):
logging_output
[
f
"loss_
{
i
}
"
]
=
l
.
item
()
if
not
self
.
xla
else
l
.
detach
()
if
self
.
infonce
:
with
torch
.
no_grad
():
if
logits
.
numel
()
==
0
:
corr
=
0
count
=
0
else
:
assert
logits
.
dim
()
>
1
,
logits
.
shape
max
=
logits
.
argmax
(
-
1
)
==
0
min
=
logits
.
argmin
(
-
1
)
==
0
if
is_xla_tensor
(
logits
):
max
,
min
=
max
*
mi
,
min
*
mi
both
=
max
&
min
corr
=
max
.
long
().
sum
()
-
both
.
long
().
sum
()
count
=
mi
.
sum
()
else
:
both
=
max
&
min
corr
=
max
.
long
().
sum
().
item
()
-
both
.
long
().
sum
().
item
()
count
=
float
(
max
.
numel
())
logging_output
[
"correct"
]
=
corr
logging_output
[
"count"
]
=
count
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
utils
.
item
(
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
))
ntokens
=
utils
.
item
(
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
))
nsentences
=
utils
.
item
(
sum
(
log
.
get
(
"nsentences"
,
0
)
for
log
in
logging_outputs
)
)
sample_size
=
utils
.
item
(
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
(
sample_size
or
1
)
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
metrics
.
log_scalar
(
"ntokens"
,
ntokens
)
metrics
.
log_scalar
(
"nsentences"
,
nsentences
)
correct
=
sum
(
log
.
get
(
"correct"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_correct"
,
correct
)
total
=
sum
(
log
.
get
(
"count"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_total"
,
total
)
if
total
>
0
:
metrics
.
log_derived
(
"accuracy"
,
lambda
meters
:
safe_round
(
meters
[
"_correct"
].
sum
/
meters
[
"_total"
].
sum
,
5
)
if
meters
[
"_total"
].
sum
>
0
else
float
(
"nan"
),
)
builtin_keys
=
{
"loss"
,
"ntokens"
,
"nsentences"
,
"sample_size"
,
"correct"
,
"count"
,
}
for
k
in
logging_outputs
[
0
]:
if
k
not
in
builtin_keys
:
val
=
sum
(
log
.
get
(
k
,
0
)
for
log
in
logging_outputs
)
if
k
.
startswith
(
"loss"
):
metrics
.
log_scalar
(
k
,
val
/
(
sample_size
or
1
)
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
else
:
metrics
.
log_scalar
(
k
,
val
/
len
(
logging_outputs
),
round
=
3
)
# FIXME: revert when gather based xla reduction is implemented
# @staticmethod
# def logging_outputs_can_be_summed() -> bool:
def
logging_outputs_can_be_summed
(
self
)
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
# XXX: Gather based reduction not implemented for xla yet.
# So we fall to sum based reduction for xla.
return
self
.
xla
PyTorch/NLP/new-Transformer/fairseq/data/__init__.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from
.dictionary
import
Dictionary
,
TruncatedDictionary
from
.fairseq_dataset
import
FairseqDataset
,
FairseqIterableDataset
from
.base_wrapper_dataset
import
BaseWrapperDataset
from
.add_target_dataset
import
AddTargetDataset
from
.append_token_dataset
import
AppendTokenDataset
from
.audio.raw_audio_dataset
import
BinarizedAudioDataset
,
FileAudioDataset
from
.audio.hubert_dataset
import
HubertDataset
from
.backtranslation_dataset
import
BacktranslationDataset
from
.bucket_pad_length_dataset
import
BucketPadLengthDataset
from
.colorize_dataset
import
ColorizeDataset
from
.concat_dataset
import
ConcatDataset
from
.concat_sentences_dataset
import
ConcatSentencesDataset
from
.denoising_dataset
import
DenoisingDataset
from
.id_dataset
import
IdDataset
from
.indexed_dataset
import
(
IndexedCachedDataset
,
IndexedDataset
,
IndexedRawTextDataset
,
MMapIndexedDataset
,
)
from
.language_pair_dataset
import
LanguagePairDataset
from
.list_dataset
import
ListDataset
from
.lm_context_window_dataset
import
LMContextWindowDataset
from
.lru_cache_dataset
import
LRUCacheDataset
from
.mask_tokens_dataset
import
MaskTokensDataset
from
.monolingual_dataset
import
MonolingualDataset
from
.multi_corpus_sampled_dataset
import
MultiCorpusSampledDataset
from
.nested_dictionary_dataset
import
NestedDictionaryDataset
from
.noising
import
NoisingDataset
from
.numel_dataset
import
NumelDataset
from
.num_samples_dataset
import
NumSamplesDataset
from
.offset_tokens_dataset
import
OffsetTokensDataset
from
.pad_dataset
import
LeftPadDataset
,
PadDataset
,
RightPadDataset
from
.prepend_dataset
import
PrependDataset
from
.prepend_token_dataset
import
PrependTokenDataset
from
.raw_label_dataset
import
RawLabelDataset
from
.replace_dataset
import
ReplaceDataset
from
.resampling_dataset
import
ResamplingDataset
from
.roll_dataset
import
RollDataset
from
.round_robin_zip_datasets
import
RoundRobinZipDatasets
from
.sort_dataset
import
SortDataset
from
.strip_token_dataset
import
StripTokenDataset
from
.subsample_dataset
import
SubsampleDataset
from
.token_block_dataset
import
TokenBlockDataset
from
.transform_eos_dataset
import
TransformEosDataset
from
.transform_eos_lang_pair_dataset
import
TransformEosLangPairDataset
from
.shorten_dataset
import
TruncateDataset
,
RandomCropDataset
from
.multilingual.sampled_multi_dataset
import
SampledMultiDataset
from
.multilingual.sampled_multi_epoch_dataset
import
SampledMultiEpochDataset
from
.fasta_dataset
import
FastaDataset
,
EncodedFastaDataset
from
.transform_eos_concat_langpair_dataset
import
TransformEosConcatLangPairDataset
from
.iterators
import
(
CountingIterator
,
EpochBatchIterator
,
GroupedIterator
,
ShardedIterator
,
)
__all__
=
[
"AddTargetDataset"
,
"AppendTokenDataset"
,
"BacktranslationDataset"
,
"BaseWrapperDataset"
,
"BinarizedAudioDataset"
,
"BucketPadLengthDataset"
,
"ColorizeDataset"
,
"ConcatDataset"
,
"ConcatSentencesDataset"
,
"CountingIterator"
,
"DenoisingDataset"
,
"Dictionary"
,
"EncodedFastaDataset"
,
"EpochBatchIterator"
,
"FairseqDataset"
,
"FairseqIterableDataset"
,
"FastaDataset"
,
"FileAudioDataset"
,
"GroupedIterator"
,
"HubertDataset"
,
"IdDataset"
,
"IndexedCachedDataset"
,
"IndexedDataset"
,
"IndexedRawTextDataset"
,
"LanguagePairDataset"
,
"LeftPadDataset"
,
"ListDataset"
,
"LMContextWindowDataset"
,
"LRUCacheDataset"
,
"MaskTokensDataset"
,
"MMapIndexedDataset"
,
"MonolingualDataset"
,
"MultiCorpusSampledDataset"
,
"NestedDictionaryDataset"
,
"NoisingDataset"
,
"NumelDataset"
,
"NumSamplesDataset"
,
"OffsetTokensDataset"
,
"PadDataset"
,
"PrependDataset"
,
"PrependTokenDataset"
,
"RandomCropDataset"
,
"RawLabelDataset"
,
"ResamplingDataset"
,
"ReplaceDataset"
,
"RightPadDataset"
,
"RollDataset"
,
"RoundRobinZipDatasets"
,
"SampledMultiDataset"
,
"SampledMultiEpochDataset"
,
"ShardedIterator"
,
"SortDataset"
,
"StripTokenDataset"
,
"SubsampleDataset"
,
"TokenBlockDataset"
,
"TransformEosDataset"
,
"TransformEosLangPairDataset"
,
"TransformEosConcatLangPairDataset"
,
"TruncateDataset"
,
"TruncatedDictionary"
,
]
PyTorch/NLP/new-Transformer/fairseq/data/add_target_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
from
.
import
BaseWrapperDataset
,
data_utils
from
fairseq.data.text_compressor
import
TextCompressor
,
TextCompressionLevel
class
AddTargetDataset
(
BaseWrapperDataset
):
def
__init__
(
self
,
dataset
,
labels
,
pad
,
eos
,
batch_targets
,
process_label
=
None
,
label_len_fn
=
None
,
add_to_input
=
False
,
text_compression_level
=
TextCompressionLevel
.
none
,
):
super
().
__init__
(
dataset
)
self
.
labels
=
labels
self
.
batch_targets
=
batch_targets
self
.
pad
=
pad
self
.
eos
=
eos
self
.
process_label
=
process_label
self
.
label_len_fn
=
label_len_fn
self
.
add_to_input
=
add_to_input
self
.
text_compressor
=
TextCompressor
(
level
=
text_compression_level
)
def
get_label
(
self
,
index
,
process_fn
=
None
):
lbl
=
self
.
labels
[
index
]
lbl
=
self
.
text_compressor
.
decompress
(
lbl
)
return
lbl
if
process_fn
is
None
else
process_fn
(
lbl
)
def
__getitem__
(
self
,
index
):
item
=
self
.
dataset
[
index
]
item
[
"label"
]
=
self
.
get_label
(
index
,
process_fn
=
self
.
process_label
)
return
item
def
size
(
self
,
index
):
sz
=
self
.
dataset
.
size
(
index
)
own_sz
=
self
.
label_len_fn
(
self
.
get_label
(
index
))
return
sz
,
own_sz
def
collater
(
self
,
samples
):
collated
=
self
.
dataset
.
collater
(
samples
)
if
len
(
collated
)
==
0
:
return
collated
indices
=
set
(
collated
[
"id"
].
tolist
())
target
=
[
s
[
"label"
]
for
s
in
samples
if
s
[
"id"
]
in
indices
]
if
self
.
add_to_input
:
eos
=
torch
.
LongTensor
([
self
.
eos
])
prev_output_tokens
=
[
torch
.
cat
([
eos
,
t
],
axis
=-
1
)
for
t
in
target
]
target
=
[
torch
.
cat
([
t
,
eos
],
axis
=-
1
)
for
t
in
target
]
collated
[
"net_input"
][
"prev_output_tokens"
]
=
prev_output_tokens
if
self
.
batch_targets
:
collated
[
"target_lengths"
]
=
torch
.
LongTensor
([
len
(
t
)
for
t
in
target
])
target
=
data_utils
.
collate_tokens
(
target
,
pad_idx
=
self
.
pad
,
left_pad
=
False
)
collated
[
"ntokens"
]
=
collated
[
"target_lengths"
].
sum
().
item
()
if
getattr
(
collated
[
"net_input"
],
"prev_output_tokens"
,
None
):
collated
[
"net_input"
][
"prev_output_tokens"
]
=
data_utils
.
collate_tokens
(
collated
[
"net_input"
][
"prev_output_tokens"
],
pad_idx
=
self
.
pad
,
left_pad
=
False
,
)
else
:
collated
[
"ntokens"
]
=
sum
([
len
(
t
)
for
t
in
target
])
collated
[
"target"
]
=
target
return
collated
def
filter_indices_by_size
(
self
,
indices
,
max_sizes
):
indices
,
ignored
=
data_utils
.
_filter_by_size_dynamic
(
indices
,
self
.
size
,
max_sizes
)
return
indices
,
ignored
PyTorch/NLP/new-Transformer/fairseq/data/append_token_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
numpy
as
np
import
torch
from
.
import
BaseWrapperDataset
class
AppendTokenDataset
(
BaseWrapperDataset
):
def
__init__
(
self
,
dataset
,
token
=
None
):
super
().
__init__
(
dataset
)
self
.
token
=
token
if
token
is
not
None
:
self
.
_sizes
=
np
.
array
(
dataset
.
sizes
)
+
1
else
:
self
.
_sizes
=
dataset
.
sizes
def
__getitem__
(
self
,
idx
):
item
=
self
.
dataset
[
idx
]
if
self
.
token
is
not
None
:
item
=
torch
.
cat
([
item
,
item
.
new
([
self
.
token
])])
return
item
@
property
def
sizes
(
self
):
return
self
.
_sizes
def
num_tokens
(
self
,
index
):
n
=
self
.
dataset
.
num_tokens
(
index
)
if
self
.
token
is
not
None
:
n
+=
1
return
n
def
size
(
self
,
index
):
n
=
self
.
dataset
.
size
(
index
)
if
self
.
token
is
not
None
:
n
+=
1
return
n
PyTorch/NLP/Transformer/
scripts
/__init__.py
→
PyTorch/NLP/
new-
Transformer/
fairseq/data/audio
/__init__.py
View file @
c0f05c10
File moved
PyTorch/NLP/new-Transformer/fairseq/data/audio/audio_utils.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
mmap
from
pathlib
import
Path
from
typing
import
BinaryIO
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
SF_AUDIO_FILE_EXTENSIONS
=
{
".wav"
,
".flac"
,
".ogg"
}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS
=
{
".npy"
,
".wav"
,
".flac"
,
".ogg"
}
def
convert_waveform
(
waveform
:
Union
[
np
.
ndarray
,
torch
.
Tensor
],
sample_rate
:
int
,
normalize_volume
:
bool
=
False
,
to_mono
:
bool
=
False
,
to_sample_rate
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
Union
[
np
.
ndarray
,
torch
.
Tensor
],
int
]:
"""convert a waveform:
- to a target sample rate
- from multi-channel to mono channel
- volume normalization
Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length)
sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate
Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate
"""
try
:
import
torchaudio.sox_effects
as
ta_sox
except
ImportError
:
raise
ImportError
(
"Please install torchaudio: pip install torchaudio"
)
effects
=
[]
if
normalize_volume
:
effects
.
append
([
"gain"
,
"-n"
])
if
to_sample_rate
is
not
None
and
to_sample_rate
!=
sample_rate
:
effects
.
append
([
"rate"
,
f
"
{
to_sample_rate
}
"
])
if
to_mono
and
waveform
.
shape
[
0
]
>
1
:
effects
.
append
([
"channels"
,
"1"
])
if
len
(
effects
)
>
0
:
is_np_input
=
isinstance
(
waveform
,
np
.
ndarray
)
_waveform
=
torch
.
from_numpy
(
waveform
)
if
is_np_input
else
waveform
converted
,
converted_sample_rate
=
ta_sox
.
apply_effects_tensor
(
_waveform
,
sample_rate
,
effects
)
if
is_np_input
:
converted
=
converted
.
numpy
()
return
converted
,
converted_sample_rate
return
waveform
,
sample_rate
def
get_waveform
(
path_or_fp
:
Union
[
str
,
BinaryIO
],
normalization
:
bool
=
True
,
mono
:
bool
=
True
,
frames
:
int
=
-
1
,
start
:
int
=
0
,
always_2d
:
bool
=
True
,
output_sample_rate
:
Optional
[
int
]
=
None
,
normalize_volume
:
bool
=
False
,
)
->
Tuple
[
np
.
ndarray
,
int
]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): normalize values to [-1, 1] (Default: True)
mono (bool): convert multi-channel audio to mono-channel one
frames (int): the number of frames to read. (-1 for reading all)
start (int): Where to start reading. A negative value counts from the end.
always_2d (bool): always return 2D array even for mono-channel audios
output_sample_rate (Optional[int]): output sample rate
normalize_volume (bool): normalize volume
Returns:
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
sample_rate (float): sample rate
"""
if
isinstance
(
path_or_fp
,
str
):
ext
=
Path
(
path_or_fp
).
suffix
if
ext
not
in
SF_AUDIO_FILE_EXTENSIONS
:
raise
ValueError
(
f
"Unsupported audio format:
{
ext
}
"
)
try
:
import
soundfile
as
sf
except
ImportError
:
raise
ImportError
(
"Please install soundfile: pip install soundfile"
)
waveform
,
sample_rate
=
sf
.
read
(
path_or_fp
,
dtype
=
"float32"
,
always_2d
=
True
,
frames
=
frames
,
start
=
start
)
waveform
=
waveform
.
T
# T x C -> C x T
waveform
,
sample_rate
=
convert_waveform
(
waveform
,
sample_rate
,
normalize_volume
=
normalize_volume
,
to_mono
=
mono
,
to_sample_rate
=
output_sample_rate
,
)
if
not
normalization
:
waveform
*=
2
**
15
# denormalized to 16-bit signed integers
if
not
always_2d
:
waveform
=
waveform
.
squeeze
(
axis
=
0
)
return
waveform
,
sample_rate
def
_get_kaldi_fbank
(
waveform
:
np
.
ndarray
,
sample_rate
:
int
,
n_bins
=
80
)
->
Optional
[
np
.
ndarray
]:
"""Get mel-filter bank features via PyKaldi."""
try
:
from
kaldi.feat.fbank
import
Fbank
,
FbankOptions
from
kaldi.feat.mel
import
MelBanksOptions
from
kaldi.feat.window
import
FrameExtractionOptions
from
kaldi.matrix
import
Vector
mel_opts
=
MelBanksOptions
()
mel_opts
.
num_bins
=
n_bins
frame_opts
=
FrameExtractionOptions
()
frame_opts
.
samp_freq
=
sample_rate
opts
=
FbankOptions
()
opts
.
mel_opts
=
mel_opts
opts
.
frame_opts
=
frame_opts
fbank
=
Fbank
(
opts
=
opts
)
features
=
fbank
.
compute
(
Vector
(
waveform
.
squeeze
()),
1.0
).
numpy
()
return
features
except
ImportError
:
return
None
def
_get_torchaudio_fbank
(
waveform
:
np
.
ndarray
,
sample_rate
,
n_bins
=
80
)
->
Optional
[
np
.
ndarray
]:
"""Get mel-filter bank features via TorchAudio."""
try
:
import
torchaudio.compliance.kaldi
as
ta_kaldi
waveform
=
torch
.
from_numpy
(
waveform
)
features
=
ta_kaldi
.
fbank
(
waveform
,
num_mel_bins
=
n_bins
,
sample_frequency
=
sample_rate
)
return
features
.
numpy
()
except
ImportError
:
return
None
def
get_fbank
(
path_or_fp
:
Union
[
str
,
BinaryIO
],
n_bins
=
80
)
->
np
.
ndarray
:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
waveform
,
sample_rate
=
get_waveform
(
path_or_fp
,
normalization
=
False
)
features
=
_get_kaldi_fbank
(
waveform
,
sample_rate
,
n_bins
)
if
features
is
None
:
features
=
_get_torchaudio_fbank
(
waveform
,
sample_rate
,
n_bins
)
if
features
is
None
:
raise
ImportError
(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
return
features
def
is_npy_data
(
data
:
bytes
)
->
bool
:
return
data
[
0
]
==
147
and
data
[
1
]
==
78
def
is_sf_audio_data
(
data
:
bytes
)
->
bool
:
is_wav
=
data
[
0
]
==
82
and
data
[
1
]
==
73
and
data
[
2
]
==
70
is_flac
=
data
[
0
]
==
102
and
data
[
1
]
==
76
and
data
[
2
]
==
97
is_ogg
=
data
[
0
]
==
79
and
data
[
1
]
==
103
and
data
[
2
]
==
103
return
is_wav
or
is_flac
or
is_ogg
def
mmap_read
(
path
:
str
,
offset
:
int
,
length
:
int
)
->
bytes
:
with
open
(
path
,
"rb"
)
as
f
:
with
mmap
.
mmap
(
f
.
fileno
(),
length
=
0
,
access
=
mmap
.
ACCESS_READ
)
as
mmap_o
:
data
=
mmap_o
[
offset
:
offset
+
length
]
return
data
def
read_from_stored_zip
(
zip_path
:
str
,
offset
:
int
,
length
:
int
)
->
bytes
:
return
mmap_read
(
zip_path
,
offset
,
length
)
def
parse_path
(
path
:
str
)
->
Tuple
[
str
,
List
[
int
]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if
Path
(
path
).
suffix
in
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS
:
_path
,
slice_ptr
=
path
,
[]
else
:
_path
,
*
slice_ptr
=
path
.
split
(
":"
)
if
not
Path
(
_path
).
is_file
():
raise
FileNotFoundError
(
f
"File not found:
{
_path
}
"
)
assert
len
(
slice_ptr
)
in
{
0
,
2
},
f
"Invalid path:
{
path
}
"
slice_ptr
=
[
int
(
i
)
for
i
in
slice_ptr
]
return
_path
,
slice_ptr
def
get_window
(
window_fn
:
callable
,
n_fft
:
int
,
win_length
:
int
)
->
torch
.
Tensor
:
padding
=
n_fft
-
win_length
assert
padding
>=
0
return
F
.
pad
(
window_fn
(
win_length
),
(
padding
//
2
,
padding
-
padding
//
2
))
def
get_fourier_basis
(
n_fft
:
int
)
->
torch
.
Tensor
:
basis
=
np
.
fft
.
fft
(
np
.
eye
(
n_fft
))
basis
=
np
.
vstack
(
[
np
.
real
(
basis
[:
n_fft
//
2
+
1
,
:]),
np
.
imag
(
basis
[:
n_fft
//
2
+
1
,
:])]
)
return
torch
.
from_numpy
(
basis
).
float
()
def
get_mel_filters
(
sample_rate
:
int
,
n_fft
:
int
,
n_mels
:
int
,
f_min
:
float
,
f_max
:
float
)
->
torch
.
Tensor
:
try
:
import
librosa
except
ImportError
:
raise
ImportError
(
"Please install librosa: pip install librosa"
)
basis
=
librosa
.
filters
.
mel
(
sample_rate
,
n_fft
,
n_mels
,
f_min
,
f_max
)
return
torch
.
from_numpy
(
basis
).
float
()
class
TTSSpectrogram
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
n_fft
:
int
,
win_length
:
int
,
hop_length
:
int
,
window_fn
:
callable
=
torch
.
hann_window
,
return_phase
:
bool
=
False
,
)
->
None
:
super
(
TTSSpectrogram
,
self
).
__init__
()
self
.
n_fft
=
n_fft
self
.
hop_length
=
hop_length
self
.
return_phase
=
return_phase
basis
=
get_fourier_basis
(
n_fft
).
unsqueeze
(
1
)
basis
*=
get_window
(
window_fn
,
n_fft
,
win_length
)
self
.
register_buffer
(
"basis"
,
basis
)
def
forward
(
self
,
waveform
:
torch
.
Tensor
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
padding
=
(
self
.
n_fft
//
2
,
self
.
n_fft
//
2
)
x
=
F
.
pad
(
waveform
.
unsqueeze
(
1
),
padding
,
mode
=
"reflect"
)
x
=
F
.
conv1d
(
x
,
self
.
basis
,
stride
=
self
.
hop_length
)
real_part
=
x
[:,
:
self
.
n_fft
//
2
+
1
,
:]
imag_part
=
x
[:,
self
.
n_fft
//
2
+
1
:,
:]
magnitude
=
torch
.
sqrt
(
real_part
**
2
+
imag_part
**
2
)
if
self
.
return_phase
:
phase
=
torch
.
atan2
(
imag_part
,
real_part
)
return
magnitude
,
phase
return
magnitude
class
TTSMelScale
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
n_mels
:
int
,
sample_rate
:
int
,
f_min
:
float
,
f_max
:
float
,
n_stft
:
int
)
->
None
:
super
(
TTSMelScale
,
self
).
__init__
()
basis
=
get_mel_filters
(
sample_rate
,
(
n_stft
-
1
)
*
2
,
n_mels
,
f_min
,
f_max
)
self
.
register_buffer
(
"basis"
,
basis
)
def
forward
(
self
,
specgram
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
matmul
(
self
.
basis
,
specgram
)
PyTorch/NLP/new-Transformer/fairseq/data/audio/data_cfg.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
argparse
import
Namespace
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
from
fairseq.data
import
Dictionary
def
get_config_from_yaml
(
yaml_path
:
Path
):
try
:
import
yaml
except
ImportError
:
print
(
"Please install PyYAML: pip install PyYAML"
)
config
=
{}
if
yaml_path
.
is_file
():
try
:
with
open
(
yaml_path
)
as
f
:
config
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
except
Exception
as
e
:
raise
Exception
(
f
"Failed to load config from
{
yaml_path
.
as_posix
()
}
:
{
e
}
"
)
else
:
raise
FileNotFoundError
(
f
"
{
yaml_path
.
as_posix
()
}
not found"
)
return
config
class
S2TDataConfig
(
object
):
"""Wrapper class for data config YAML"""
def
__init__
(
self
,
yaml_path
:
Path
):
self
.
config
=
get_config_from_yaml
(
yaml_path
)
self
.
root
=
yaml_path
.
parent
def
_auto_convert_to_abs_path
(
self
,
x
):
if
isinstance
(
x
,
str
):
if
not
Path
(
x
).
exists
()
and
(
self
.
root
/
x
).
exists
():
return
(
self
.
root
/
x
).
as_posix
()
elif
isinstance
(
x
,
dict
):
return
{
k
:
self
.
_auto_convert_to_abs_path
(
v
)
for
k
,
v
in
x
.
items
()}
return
x
@
property
def
vocab_filename
(
self
):
"""fairseq vocabulary file under data root"""
return
self
.
config
.
get
(
"vocab_filename"
,
"dict.txt"
)
@
property
def
speaker_set_filename
(
self
):
"""speaker set file under data root"""
return
self
.
config
.
get
(
"speaker_set_filename"
,
None
)
@
property
def
shuffle
(
self
)
->
bool
:
"""Shuffle dataset samples before batching"""
return
self
.
config
.
get
(
"shuffle"
,
False
)
@
property
def
pre_tokenizer
(
self
)
->
Dict
:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer
=
self
.
config
.
get
(
"pre_tokenizer"
,
{
"tokenizer"
:
None
})
return
self
.
_auto_convert_to_abs_path
(
tokenizer
)
@
property
def
bpe_tokenizer
(
self
)
->
Dict
:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer
=
self
.
config
.
get
(
"bpe_tokenizer"
,
{
"bpe"
:
None
})
return
self
.
_auto_convert_to_abs_path
(
tokenizer
)
@
property
def
prepend_tgt_lang_tag
(
self
)
->
bool
:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return
self
.
config
.
get
(
"prepend_tgt_lang_tag"
,
False
)
@
property
def
prepend_bos_and_append_tgt_lang_tag
(
self
)
->
bool
:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return
self
.
config
.
get
(
"prepend_bos_and_append_tgt_lang_tag"
,
False
)
@
property
def
input_feat_per_channel
(
self
):
"""The dimension of input features (per audio channel)"""
return
self
.
config
.
get
(
"input_feat_per_channel"
,
80
)
@
property
def
input_channels
(
self
):
"""The number of channels in the input audio"""
return
self
.
config
.
get
(
"input_channels"
,
1
)
@
property
def
sample_rate
(
self
):
return
self
.
config
.
get
(
"sample_rate"
,
16_000
)
@
property
def
sampling_alpha
(
self
):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
return
self
.
config
.
get
(
"sampling_alpha"
,
1.0
)
@
property
def
use_audio_input
(
self
):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
return
self
.
config
.
get
(
"use_audio_input"
,
False
)
def
standardize_audio
(
self
)
->
bool
:
return
self
.
use_audio_input
and
self
.
config
.
get
(
"standardize_audio"
,
False
)
@
property
def
use_sample_rate
(
self
):
"""Needed by the dataset loader to see if the model requires
raw audio with specific sample rate as inputs."""
return
self
.
config
.
get
(
"use_sample_rate"
,
16000
)
@
property
def
audio_root
(
self
):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
return
self
.
config
.
get
(
"audio_root"
,
""
)
def
get_feature_transforms
(
self
,
split
,
is_train
):
"""Split-specific feature transforms. Allowing train set
wildcard `_train`, evaluation set wildcard `_eval` and general
wildcard `*` for matching."""
from
copy
import
deepcopy
cfg
=
deepcopy
(
self
.
config
)
_cur
=
cfg
.
get
(
"transforms"
,
{})
cur
=
_cur
.
get
(
split
)
cur
=
_cur
.
get
(
"_train"
)
if
cur
is
None
and
is_train
else
cur
cur
=
_cur
.
get
(
"_eval"
)
if
cur
is
None
and
not
is_train
else
cur
cur
=
_cur
.
get
(
"*"
)
if
cur
is
None
else
cur
cfg
[
"transforms"
]
=
cur
return
cfg
@
property
def
global_cmvn_stats_npz
(
self
)
->
Optional
[
str
]:
path
=
self
.
config
.
get
(
"global_cmvn"
,
{}).
get
(
"stats_npz_path"
,
None
)
return
self
.
_auto_convert_to_abs_path
(
path
)
@
property
def
vocoder
(
self
)
->
Dict
[
str
,
str
]:
vocoder
=
self
.
config
.
get
(
"vocoder"
,
{
"type"
:
"griffin_lim"
})
return
self
.
_auto_convert_to_abs_path
(
vocoder
)
@
property
def
hub
(
self
)
->
Dict
[
str
,
str
]:
return
self
.
config
.
get
(
"hub"
,
{})
class
S2SDataConfig
(
S2TDataConfig
):
"""Wrapper class for data config YAML"""
@
property
def
vocab_filename
(
self
):
"""fairseq vocabulary file under data root"""
return
self
.
config
.
get
(
"vocab_filename"
,
None
)
@
property
def
pre_tokenizer
(
self
)
->
Dict
:
return
None
@
property
def
bpe_tokenizer
(
self
)
->
Dict
:
return
None
@
property
def
input_transformed_channels
(
self
):
"""The number of channels in the audio after feature transforms"""
# TODO: move this into individual transforms
_cur
=
self
.
config
.
get
(
"transforms"
,
{})
cur
=
_cur
.
get
(
"_train"
,
[])
_channels
=
self
.
input_channels
if
"delta_deltas"
in
cur
:
_channels
*=
3
return
_channels
@
property
def
output_sample_rate
(
self
):
"""The audio sample rate of output target speech"""
return
self
.
config
.
get
(
"output_sample_rate"
,
22050
)
@
property
def
target_speaker_embed
(
self
):
"""Target speaker embedding file (one line per target audio sample)"""
return
self
.
config
.
get
(
"target_speaker_embed"
,
None
)
@
property
def
prepend_tgt_lang_tag_as_bos
(
self
)
->
bool
:
"""Prepend target lang ID token as the target BOS."""
return
self
.
config
.
get
(
"prepend_tgt_lang_tag_as_bos"
,
False
)
class
MultitaskConfig
(
object
):
"""Wrapper class for data config YAML"""
def
__init__
(
self
,
yaml_path
:
Path
):
config
=
get_config_from_yaml
(
yaml_path
)
self
.
config
=
{}
for
k
,
v
in
config
.
items
():
self
.
config
[
k
]
=
SingleTaskConfig
(
k
,
v
)
def
get_all_tasks
(
self
):
return
self
.
config
def
get_single_task
(
self
,
name
):
assert
name
in
self
.
config
,
f
"multitask '
{
name
}
' does not exist!"
return
self
.
config
[
name
]
class
SingleTaskConfig
(
object
):
def
__init__
(
self
,
name
,
config
):
self
.
task_name
=
name
self
.
config
=
config
dict_path
=
config
.
get
(
"dict"
,
""
)
self
.
tgt_dict
=
Dictionary
.
load
(
dict_path
)
if
Path
(
dict_path
).
exists
()
else
None
@
property
def
data
(
self
):
return
self
.
config
.
get
(
"data"
,
""
)
@
property
def
decoder_type
(
self
):
return
self
.
config
.
get
(
"decoder_type"
,
"transformer"
)
@
property
def
decoder_args
(
self
):
"""Decoder arch related args"""
args
=
self
.
config
.
get
(
"decoder_args"
,
{})
return
Namespace
(
**
args
)
@
property
def
criterion_cfg
(
self
):
"""cfg for the multitask criterion"""
if
self
.
decoder_type
==
"ctc"
:
from
fairseq.criterions.ctc
import
CtcCriterionConfig
cfg
=
CtcCriterionConfig
cfg
.
zero_infinity
=
self
.
config
.
get
(
"zero_infinity"
,
True
)
else
:
from
fairseq.criterions.label_smoothed_cross_entropy
import
(
LabelSmoothedCrossEntropyCriterionConfig
,
)
cfg
=
LabelSmoothedCrossEntropyCriterionConfig
cfg
.
label_smoothing
=
self
.
config
.
get
(
"label_smoothing"
,
0.2
)
return
cfg
@
property
def
input_from
(
self
):
"""Condition on encoder/decoder of the main model"""
return
"decoder"
if
"decoder_layer"
in
self
.
config
else
"encoder"
@
property
def
input_layer
(
self
):
if
self
.
input_from
==
"decoder"
:
return
self
.
config
[
"decoder_layer"
]
-
1
else
:
# default using the output from the last encoder layer (-1)
return
self
.
config
.
get
(
"encoder_layer"
,
0
)
-
1
@
property
def
loss_weight_schedule
(
self
):
return
(
"decay"
if
"loss_weight_max"
in
self
.
config
and
"loss_weight_decay_steps"
in
self
.
config
else
"fixed"
)
def
get_loss_weight
(
self
,
num_updates
):
if
self
.
loss_weight_schedule
==
"fixed"
:
weight
=
self
.
config
.
get
(
"loss_weight"
,
1.0
)
else
:
# "decay"
assert
(
self
.
config
.
get
(
"loss_weight_decay_steps"
,
0
)
>
0
),
"loss_weight_decay_steps must be greater than 0 for a decay schedule"
loss_weight_min
=
self
.
config
.
get
(
"loss_weight_min"
,
0.0001
)
loss_weight_decay_stepsize
=
(
self
.
config
[
"loss_weight_max"
]
-
loss_weight_min
)
/
self
.
config
[
"loss_weight_decay_steps"
]
weight
=
max
(
self
.
config
[
"loss_weight_max"
]
-
loss_weight_decay_stepsize
*
num_updates
,
loss_weight_min
,
)
return
weight
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/__init__.py
0 → 100644
View file @
c0f05c10
import
importlib
import
os
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
Optional
class
AudioFeatureTransform
(
ABC
):
@
classmethod
@
abstractmethod
def
from_config_dict
(
cls
,
config
:
Optional
[
Dict
]
=
None
):
pass
AUDIO_FEATURE_TRANSFORM_REGISTRY
=
{}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES
=
set
()
def
register_audio_feature_transform
(
name
):
def
register_audio_feature_transform_cls
(
cls
):
if
name
in
AUDIO_FEATURE_TRANSFORM_REGISTRY
:
raise
ValueError
(
f
"Cannot register duplicate transform (
{
name
}
)"
)
if
not
issubclass
(
cls
,
AudioFeatureTransform
):
raise
ValueError
(
f
"Transform (
{
name
}
:
{
cls
.
__name__
}
) must extend "
"AudioFeatureTransform"
)
if
cls
.
__name__
in
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES
:
raise
ValueError
(
f
"Cannot register audio feature transform with duplicate "
f
"class name (
{
cls
.
__name__
}
)"
)
AUDIO_FEATURE_TRANSFORM_REGISTRY
[
name
]
=
cls
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES
.
add
(
cls
.
__name__
)
return
cls
return
register_audio_feature_transform_cls
def
get_audio_feature_transform
(
name
):
return
AUDIO_FEATURE_TRANSFORM_REGISTRY
[
name
]
transforms_dir
=
os
.
path
.
dirname
(
__file__
)
for
file
in
os
.
listdir
(
transforms_dir
):
path
=
os
.
path
.
join
(
transforms_dir
,
file
)
if
(
not
file
.
startswith
(
"_"
)
and
not
file
.
startswith
(
"."
)
and
(
file
.
endswith
(
".py"
)
or
os
.
path
.
isdir
(
path
))
):
name
=
file
[:
file
.
find
(
".py"
)]
if
file
.
endswith
(
".py"
)
else
file
importlib
.
import_module
(
"fairseq.data.audio.feature_transforms."
+
name
)
class
CompositeAudioFeatureTransform
(
AudioFeatureTransform
):
@
classmethod
def
from_config_dict
(
cls
,
config
=
None
):
_config
=
{}
if
config
is
None
else
config
_transforms
=
_config
.
get
(
"transforms"
)
if
_transforms
is
None
:
return
None
transforms
=
[
get_audio_feature_transform
(
_t
).
from_config_dict
(
_config
.
get
(
_t
))
for
_t
in
_transforms
]
return
CompositeAudioFeatureTransform
(
transforms
)
def
__init__
(
self
,
transforms
):
self
.
transforms
=
[
t
for
t
in
transforms
if
t
is
not
None
]
def
__call__
(
self
,
x
):
for
t
in
self
.
transforms
:
x
=
t
(
x
)
return
x
def
__repr__
(
self
):
format_string
=
(
[
self
.
__class__
.
__name__
+
"("
]
+
[
f
"
{
t
.
__repr__
()
}
"
for
t
in
self
.
transforms
]
+
[
")"
]
)
return
"
\n
"
.
join
(
format_string
)
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/delta_deltas.py
0 → 100644
View file @
c0f05c10
import
numpy
as
np
import
torch
from
fairseq.data.audio.feature_transforms
import
(
AudioFeatureTransform
,
register_audio_feature_transform
,
)
@
register_audio_feature_transform
(
"delta_deltas"
)
class
DeltaDeltas
(
AudioFeatureTransform
):
"""Expand delta-deltas features from spectrum."""
@
classmethod
def
from_config_dict
(
cls
,
config
=
None
):
_config
=
{}
if
config
is
None
else
config
return
DeltaDeltas
(
_config
.
get
(
"win_length"
,
5
))
def
__init__
(
self
,
win_length
=
5
):
self
.
win_length
=
win_length
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
def
__call__
(
self
,
spectrogram
):
from
torchaudio.functional
import
compute_deltas
assert
len
(
spectrogram
.
shape
)
==
2
,
"spectrogram must be a 2-D tensor."
# spectrogram is T x F, while compute_deltas takes (…, F, T)
spectrogram
=
torch
.
from_numpy
(
spectrogram
).
transpose
(
0
,
1
)
delta
=
compute_deltas
(
spectrogram
)
delta_delta
=
compute_deltas
(
delta
)
out_feat
=
np
.
concatenate
(
[
spectrogram
,
delta
.
numpy
(),
delta_delta
.
numpy
()],
axis
=
0
)
out_feat
=
np
.
transpose
(
out_feat
)
return
out_feat
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/global_cmvn.py
0 → 100644
View file @
c0f05c10
import
numpy
as
np
from
fairseq.data.audio.feature_transforms
import
(
AudioFeatureTransform
,
register_audio_feature_transform
,
)
@
register_audio_feature_transform
(
"global_cmvn"
)
class
GlobalCMVN
(
AudioFeatureTransform
):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@
classmethod
def
from_config_dict
(
cls
,
config
=
None
):
_config
=
{}
if
config
is
None
else
config
return
GlobalCMVN
(
_config
.
get
(
"stats_npz_path"
))
def
__init__
(
self
,
stats_npz_path
):
self
.
stats_npz_path
=
stats_npz_path
stats
=
np
.
load
(
stats_npz_path
)
self
.
mean
,
self
.
std
=
stats
[
"mean"
],
stats
[
"std"
]
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
f
'(stats_npz_path="
{
self
.
stats_npz_path
}
")'
def
__call__
(
self
,
x
):
x
=
np
.
subtract
(
x
,
self
.
mean
)
x
=
np
.
divide
(
x
,
self
.
std
)
return
x
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/specaugment.py
0 → 100644
View file @
c0f05c10
import
math
import
numbers
from
typing
import
Optional
import
numpy
as
np
from
fairseq.data.audio.feature_transforms
import
(
AudioFeatureTransform
,
register_audio_feature_transform
,
)
@
register_audio_feature_transform
(
"specaugment"
)
class
SpecAugmentTransform
(
AudioFeatureTransform
):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@
classmethod
def
from_config_dict
(
cls
,
config
=
None
):
_config
=
{}
if
config
is
None
else
config
return
SpecAugmentTransform
(
_config
.
get
(
"time_warp_W"
,
0
),
_config
.
get
(
"freq_mask_N"
,
0
),
_config
.
get
(
"freq_mask_F"
,
0
),
_config
.
get
(
"time_mask_N"
,
0
),
_config
.
get
(
"time_mask_T"
,
0
),
_config
.
get
(
"time_mask_p"
,
0.0
),
_config
.
get
(
"mask_value"
,
None
),
)
def
__init__
(
self
,
time_warp_w
:
int
=
0
,
freq_mask_n
:
int
=
0
,
freq_mask_f
:
int
=
0
,
time_mask_n
:
int
=
0
,
time_mask_t
:
int
=
0
,
time_mask_p
:
float
=
0.0
,
mask_value
:
Optional
[
float
]
=
0.0
,
):
# Sanity checks
assert
mask_value
is
None
or
isinstance
(
mask_value
,
numbers
.
Number
),
f
"mask_value (type:
{
type
(
mask_value
)
}
) must be None or a number"
if
freq_mask_n
>
0
:
assert
freq_mask_f
>
0
,
(
f
"freq_mask_F (
{
freq_mask_f
}
) "
f
"must be larger than 0 when doing freq masking."
)
if
time_mask_n
>
0
:
assert
time_mask_t
>
0
,
(
f
"time_mask_T (
{
time_mask_t
}
) must be larger than 0 when "
f
"doing time masking."
)
self
.
time_warp_w
=
time_warp_w
self
.
freq_mask_n
=
freq_mask_n
self
.
freq_mask_f
=
freq_mask_f
self
.
time_mask_n
=
time_mask_n
self
.
time_mask_t
=
time_mask_t
self
.
time_mask_p
=
time_mask_p
self
.
mask_value
=
mask_value
def
__repr__
(
self
):
return
(
self
.
__class__
.
__name__
+
"("
+
", "
.
join
(
[
f
"time_warp_w=
{
self
.
time_warp_w
}
"
,
f
"freq_mask_n=
{
self
.
freq_mask_n
}
"
,
f
"freq_mask_f=
{
self
.
freq_mask_f
}
"
,
f
"time_mask_n=
{
self
.
time_mask_n
}
"
,
f
"time_mask_t=
{
self
.
time_mask_t
}
"
,
f
"time_mask_p=
{
self
.
time_mask_p
}
"
,
]
)
+
")"
)
def
__call__
(
self
,
spectrogram
):
assert
len
(
spectrogram
.
shape
)
==
2
,
"spectrogram must be a 2-D tensor."
distorted
=
spectrogram
.
copy
()
# make a copy of input spectrogram.
num_frames
=
spectrogram
.
shape
[
0
]
# or 'tau' in the paper.
num_freqs
=
spectrogram
.
shape
[
1
]
# or 'miu' in the paper.
mask_value
=
self
.
mask_value
if
mask_value
is
None
:
# if no value was specified, use local mean.
mask_value
=
spectrogram
.
mean
()
if
num_frames
==
0
:
return
spectrogram
if
num_freqs
<
self
.
freq_mask_f
:
return
spectrogram
if
self
.
time_warp_w
>
0
:
if
2
*
self
.
time_warp_w
<
num_frames
:
import
cv2
w0
=
np
.
random
.
randint
(
self
.
time_warp_w
,
num_frames
-
self
.
time_warp_w
)
w
=
np
.
random
.
randint
(
-
self
.
time_warp_w
+
1
,
self
.
time_warp_w
)
upper
,
lower
=
distorted
[:
w0
,
:],
distorted
[
w0
:,
:]
upper
=
cv2
.
resize
(
upper
,
dsize
=
(
num_freqs
,
w0
+
w
),
interpolation
=
cv2
.
INTER_LINEAR
)
lower
=
cv2
.
resize
(
lower
,
dsize
=
(
num_freqs
,
num_frames
-
w0
-
w
),
interpolation
=
cv2
.
INTER_LINEAR
,
)
distorted
=
np
.
concatenate
((
upper
,
lower
),
axis
=
0
)
for
_i
in
range
(
self
.
freq_mask_n
):
f
=
np
.
random
.
randint
(
0
,
self
.
freq_mask_f
)
f0
=
np
.
random
.
randint
(
0
,
num_freqs
-
f
)
if
f
!=
0
:
distorted
[:,
f0
:
f0
+
f
]
=
mask_value
max_time_mask_t
=
min
(
self
.
time_mask_t
,
math
.
floor
(
num_frames
*
self
.
time_mask_p
)
)
if
max_time_mask_t
<
1
:
return
distorted
for
_i
in
range
(
self
.
time_mask_n
):
t
=
np
.
random
.
randint
(
0
,
max_time_mask_t
)
t0
=
np
.
random
.
randint
(
0
,
num_frames
-
t
)
if
t
!=
0
:
distorted
[
t0
:
t0
+
t
,
:]
=
mask_value
return
distorted
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/utterance_cmvn.py
0 → 100644
View file @
c0f05c10
import
numpy
as
np
from
fairseq.data.audio.feature_transforms
import
(
AudioFeatureTransform
,
register_audio_feature_transform
,
)
@
register_audio_feature_transform
(
"utterance_cmvn"
)
class
UtteranceCMVN
(
AudioFeatureTransform
):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@
classmethod
def
from_config_dict
(
cls
,
config
=
None
):
_config
=
{}
if
config
is
None
else
config
return
UtteranceCMVN
(
_config
.
get
(
"norm_means"
,
True
),
_config
.
get
(
"norm_vars"
,
True
),
)
def
__init__
(
self
,
norm_means
=
True
,
norm_vars
=
True
):
self
.
norm_means
,
self
.
norm_vars
=
norm_means
,
norm_vars
def
__repr__
(
self
):
return
(
self
.
__class__
.
__name__
+
f
"(norm_means=
{
self
.
norm_means
}
, norm_vars=
{
self
.
norm_vars
}
)"
)
def
__call__
(
self
,
x
):
mean
=
x
.
mean
(
axis
=
0
)
square_sums
=
(
x
**
2
).
sum
(
axis
=
0
)
if
self
.
norm_means
:
x
=
np
.
subtract
(
x
,
mean
)
if
self
.
norm_vars
:
var
=
square_sums
/
x
.
shape
[
0
]
-
mean
**
2
std
=
np
.
sqrt
(
np
.
maximum
(
var
,
1e-10
))
x
=
np
.
divide
(
x
,
std
)
return
x
PyTorch/NLP/new-Transformer/fairseq/data/audio/frm_text_to_speech_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.abs
import
csv
import
logging
import
os.path
as
op
from
typing
import
List
,
Optional
import
numpy
as
np
import
torch
from
fairseq.data
import
Dictionary
from
fairseq.data.audio.speech_to_text_dataset
import
S2TDataConfig
from
fairseq.data.audio.text_to_speech_dataset
import
(
TextToSpeechDataset
,
TextToSpeechDatasetCreator
,
)
logger
=
logging
.
getLogger
(
__name__
)
class
FrmTextToSpeechDataset
(
TextToSpeechDataset
):
def
__init__
(
self
,
split
:
str
,
is_train_split
:
bool
,
data_cfg
:
S2TDataConfig
,
audio_paths
:
List
[
str
],
n_frames
:
List
[
int
],
src_texts
:
Optional
[
List
[
str
]]
=
None
,
tgt_texts
:
Optional
[
List
[
str
]]
=
None
,
speakers
:
Optional
[
List
[
str
]]
=
None
,
src_langs
:
Optional
[
List
[
str
]]
=
None
,
tgt_langs
:
Optional
[
List
[
str
]]
=
None
,
ids
:
Optional
[
List
[
str
]]
=
None
,
tgt_dict
:
Optional
[
Dictionary
]
=
None
,
pre_tokenizer
=
None
,
bpe_tokenizer
=
None
,
n_frames_per_step
=
1
,
speaker_to_id
=
None
,
do_chunk
=
False
,
chunk_bound
=-
1
,
chunk_init
=
50
,
chunk_incr
=
5
,
add_eos
=
True
,
dedup
=
True
,
ref_fpu
=-
1
,
):
# It assumes texts are encoded at a fixed frame-rate
super
().
__init__
(
split
=
split
,
is_train_split
=
is_train_split
,
data_cfg
=
data_cfg
,
audio_paths
=
audio_paths
,
n_frames
=
n_frames
,
src_texts
=
src_texts
,
tgt_texts
=
tgt_texts
,
speakers
=
speakers
,
src_langs
=
src_langs
,
tgt_langs
=
tgt_langs
,
ids
=
ids
,
tgt_dict
=
tgt_dict
,
pre_tokenizer
=
pre_tokenizer
,
bpe_tokenizer
=
bpe_tokenizer
,
n_frames_per_step
=
n_frames_per_step
,
speaker_to_id
=
speaker_to_id
,
)
self
.
do_chunk
=
do_chunk
self
.
chunk_bound
=
chunk_bound
self
.
chunk_init
=
chunk_init
self
.
chunk_incr
=
chunk_incr
self
.
add_eos
=
add_eos
self
.
dedup
=
dedup
self
.
ref_fpu
=
ref_fpu
self
.
chunk_size
=
-
1
if
do_chunk
:
assert
self
.
chunk_incr
>=
0
assert
self
.
pre_tokenizer
is
None
def
__getitem__
(
self
,
index
):
index
,
source
,
target
,
speaker_id
,
_
,
_
,
_
=
super
().
__getitem__
(
index
)
if
target
[
-
1
].
item
()
==
self
.
tgt_dict
.
eos_index
:
target
=
target
[:
-
1
]
fpu
=
source
.
size
(
0
)
/
target
.
size
(
0
)
# frame-per-unit
fps
=
self
.
n_frames_per_step
assert
(
self
.
ref_fpu
==
-
1
or
abs
((
fpu
*
fps
-
self
.
ref_fpu
)
/
self
.
ref_fpu
)
<
0.1
),
f
"
{
fpu
*
fps
}
!=
{
self
.
ref_fpu
}
"
# only chunk training split
if
self
.
is_train_split
and
self
.
do_chunk
and
self
.
chunk_size
>
0
:
lang
=
target
[:
int
(
self
.
data_cfg
.
prepend_tgt_lang_tag
)]
text
=
target
[
int
(
self
.
data_cfg
.
prepend_tgt_lang_tag
)
:]
size
=
len
(
text
)
chunk_size
=
min
(
self
.
chunk_size
,
size
)
chunk_start
=
np
.
random
.
randint
(
size
-
chunk_size
+
1
)
text
=
text
[
chunk_start
:
chunk_start
+
chunk_size
]
target
=
torch
.
cat
((
lang
,
text
),
0
)
f_size
=
int
(
np
.
floor
(
chunk_size
*
fpu
))
f_start
=
int
(
np
.
floor
(
chunk_start
*
fpu
))
assert
f_size
>
0
source
=
source
[
f_start
:
f_start
+
f_size
,
:]
if
self
.
dedup
:
target
=
torch
.
unique_consecutive
(
target
)
if
self
.
add_eos
:
eos_idx
=
self
.
tgt_dict
.
eos_index
target
=
torch
.
cat
((
target
,
torch
.
LongTensor
([
eos_idx
])),
0
)
return
index
,
source
,
target
,
speaker_id
def
set_epoch
(
self
,
epoch
):
if
self
.
is_train_split
and
self
.
do_chunk
:
old
=
self
.
chunk_size
self
.
chunk_size
=
self
.
chunk_init
+
epoch
*
self
.
chunk_incr
if
self
.
chunk_bound
>
0
:
self
.
chunk_size
=
min
(
self
.
chunk_size
,
self
.
chunk_bound
)
logger
.
info
(
(
f
"
{
self
.
split
}
: setting chunk size "
f
"from
{
old
}
to
{
self
.
chunk_size
}
"
)
)
class
FrmTextToSpeechDatasetCreator
(
TextToSpeechDatasetCreator
):
# inherit for key names
@
classmethod
def
from_tsv
(
cls
,
root
:
str
,
data_cfg
:
S2TDataConfig
,
split
:
str
,
tgt_dict
,
pre_tokenizer
,
bpe_tokenizer
,
is_train_split
:
bool
,
n_frames_per_step
:
int
,
speaker_to_id
,
do_chunk
:
bool
=
False
,
chunk_bound
:
int
=
-
1
,
chunk_init
:
int
=
50
,
chunk_incr
:
int
=
5
,
add_eos
:
bool
=
True
,
dedup
:
bool
=
True
,
ref_fpu
:
float
=
-
1
,
)
->
FrmTextToSpeechDataset
:
tsv_path
=
op
.
join
(
root
,
f
"
{
split
}
.tsv"
)
if
not
op
.
isfile
(
tsv_path
):
raise
FileNotFoundError
(
f
"Dataset not found:
{
tsv_path
}
"
)
with
open
(
tsv_path
)
as
f
:
reader
=
csv
.
DictReader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
None
,
doublequote
=
False
,
lineterminator
=
"
\n
"
,
quoting
=
csv
.
QUOTE_NONE
,
)
s
=
[
dict
(
e
)
for
e
in
reader
]
assert
len
(
s
)
>
0
ids
=
[
ss
[
cls
.
KEY_ID
]
for
ss
in
s
]
audio_paths
=
[
op
.
join
(
data_cfg
.
audio_root
,
ss
[
cls
.
KEY_AUDIO
])
for
ss
in
s
]
n_frames
=
[
int
(
ss
[
cls
.
KEY_N_FRAMES
])
for
ss
in
s
]
tgt_texts
=
[
ss
[
cls
.
KEY_TGT_TEXT
]
for
ss
in
s
]
src_texts
=
[
ss
.
get
(
cls
.
KEY_SRC_TEXT
,
cls
.
DEFAULT_SRC_TEXT
)
for
ss
in
s
]
speakers
=
[
ss
.
get
(
cls
.
KEY_SPEAKER
,
cls
.
DEFAULT_SPEAKER
)
for
ss
in
s
]
src_langs
=
[
ss
.
get
(
cls
.
KEY_SRC_LANG
,
cls
.
DEFAULT_LANG
)
for
ss
in
s
]
tgt_langs
=
[
ss
.
get
(
cls
.
KEY_TGT_LANG
,
cls
.
DEFAULT_LANG
)
for
ss
in
s
]
return
FrmTextToSpeechDataset
(
split
=
split
,
is_train_split
=
is_train_split
,
data_cfg
=
data_cfg
,
audio_paths
=
audio_paths
,
n_frames
=
n_frames
,
src_texts
=
src_texts
,
tgt_texts
=
tgt_texts
,
speakers
=
speakers
,
src_langs
=
src_langs
,
tgt_langs
=
tgt_langs
,
ids
=
ids
,
tgt_dict
=
tgt_dict
,
pre_tokenizer
=
pre_tokenizer
,
bpe_tokenizer
=
bpe_tokenizer
,
n_frames_per_step
=
n_frames_per_step
,
speaker_to_id
=
speaker_to_id
,
do_chunk
=
do_chunk
,
chunk_bound
=
chunk_bound
,
chunk_init
=
chunk_init
,
chunk_incr
=
chunk_incr
,
add_eos
=
add_eos
,
dedup
=
dedup
,
ref_fpu
=
ref_fpu
,
)
PyTorch/NLP/new-Transformer/fairseq/data/audio/hubert_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
itertools
import
logging
import
os
import
sys
from
typing
import
Any
,
List
,
Optional
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
fairseq.data
import
data_utils
from
fairseq.data.fairseq_dataset
import
FairseqDataset
logger
=
logging
.
getLogger
(
__name__
)
def
load_audio
(
manifest_path
,
max_keep
,
min_keep
):
n_long
,
n_short
=
0
,
0
names
,
inds
,
sizes
=
[],
[],
[]
with
open
(
manifest_path
)
as
f
:
root
=
f
.
readline
().
strip
()
for
ind
,
line
in
enumerate
(
f
):
items
=
line
.
strip
().
split
(
"
\t
"
)
assert
len
(
items
)
==
2
,
line
sz
=
int
(
items
[
1
])
if
min_keep
is
not
None
and
sz
<
min_keep
:
n_short
+=
1
elif
max_keep
is
not
None
and
sz
>
max_keep
:
n_long
+=
1
else
:
names
.
append
(
items
[
0
])
inds
.
append
(
ind
)
sizes
.
append
(
sz
)
tot
=
ind
+
1
logger
.
info
(
(
f
"max_keep=
{
max_keep
}
, min_keep=
{
min_keep
}
, "
f
"loaded
{
len
(
names
)
}
, skipped
{
n_short
}
short and
{
n_long
}
long, "
f
"longest-loaded=
{
max
(
sizes
)
}
, shortest-loaded=
{
min
(
sizes
)
}
"
)
)
return
root
,
names
,
inds
,
tot
,
sizes
def
load_label
(
label_path
,
inds
,
tot
):
with
open
(
label_path
)
as
f
:
labels
=
[
line
.
rstrip
()
for
line
in
f
]
assert
(
len
(
labels
)
==
tot
),
f
"number of labels does not match (
{
len
(
labels
)
}
!=
{
tot
}
)"
labels
=
[
labels
[
i
]
for
i
in
inds
]
return
labels
def
load_label_offset
(
label_path
,
inds
,
tot
):
with
open
(
label_path
)
as
f
:
code_lengths
=
[
len
(
line
.
encode
(
"utf-8"
))
for
line
in
f
]
assert
(
len
(
code_lengths
)
==
tot
),
f
"number of labels does not match (
{
len
(
code_lengths
)
}
!=
{
tot
}
)"
offsets
=
list
(
itertools
.
accumulate
([
0
]
+
code_lengths
))
offsets
=
[(
offsets
[
i
],
offsets
[
i
+
1
])
for
i
in
inds
]
return
offsets
def
verify_label_lengths
(
audio_sizes
,
audio_rate
,
label_path
,
label_rate
,
inds
,
tot
,
tol
=
0.1
,
# tolerance in seconds
):
if
label_rate
<
0
:
logger
.
info
(
f
"
{
label_path
}
is sequence label. skipped"
)
return
with
open
(
label_path
)
as
f
:
lengths
=
[
len
(
line
.
rstrip
().
split
())
for
line
in
f
]
assert
len
(
lengths
)
==
tot
lengths
=
[
lengths
[
i
]
for
i
in
inds
]
num_invalid
=
0
for
i
,
ind
in
enumerate
(
inds
):
dur_from_audio
=
audio_sizes
[
i
]
/
audio_rate
dur_from_label
=
lengths
[
i
]
/
label_rate
if
abs
(
dur_from_audio
-
dur_from_label
)
>
tol
:
logger
.
warning
(
(
f
"audio and label duration differ too much "
f
"(|
{
dur_from_audio
}
-
{
dur_from_label
}
| >
{
tol
}
) "
f
"in line
{
ind
+
1
}
of
{
label_path
}
. Check if `label_rate` "
f
"is correctly set (currently
{
label_rate
}
). "
f
"num. of samples =
{
audio_sizes
[
i
]
}
; "
f
"label length =
{
lengths
[
i
]
}
"
)
)
num_invalid
+=
1
if
num_invalid
>
0
:
logger
.
warning
(
f
"total
{
num_invalid
}
(audio, label) pairs with mismatched lengths"
)
class
HubertDataset
(
FairseqDataset
):
def
__init__
(
self
,
manifest_path
:
str
,
sample_rate
:
float
,
label_paths
:
List
[
str
],
label_rates
:
Union
[
List
[
float
],
float
],
# -1 for sequence labels
pad_list
:
List
[
str
],
eos_list
:
List
[
str
],
label_processors
:
Optional
[
List
[
Any
]]
=
None
,
max_keep_sample_size
:
Optional
[
int
]
=
None
,
min_keep_sample_size
:
Optional
[
int
]
=
None
,
max_sample_size
:
Optional
[
int
]
=
None
,
shuffle
:
bool
=
True
,
pad_audio
:
bool
=
False
,
normalize
:
bool
=
False
,
store_labels
:
bool
=
True
,
random_crop
:
bool
=
False
,
single_target
:
bool
=
False
,
):
self
.
audio_root
,
self
.
audio_names
,
inds
,
tot
,
self
.
sizes
=
load_audio
(
manifest_path
,
max_keep_sample_size
,
min_keep_sample_size
)
self
.
sample_rate
=
sample_rate
self
.
shuffle
=
shuffle
self
.
random_crop
=
random_crop
self
.
num_labels
=
len
(
label_paths
)
self
.
pad_list
=
pad_list
self
.
eos_list
=
eos_list
self
.
label_processors
=
label_processors
self
.
single_target
=
single_target
self
.
label_rates
=
(
[
label_rates
for
_
in
range
(
len
(
label_paths
))]
if
isinstance
(
label_rates
,
float
)
else
label_rates
)
self
.
store_labels
=
store_labels
if
store_labels
:
self
.
label_list
=
[
load_label
(
p
,
inds
,
tot
)
for
p
in
label_paths
]
else
:
self
.
label_paths
=
label_paths
self
.
label_offsets_list
=
[
load_label_offset
(
p
,
inds
,
tot
)
for
p
in
label_paths
]
assert
label_processors
is
None
or
len
(
label_processors
)
==
self
.
num_labels
for
label_path
,
label_rate
in
zip
(
label_paths
,
self
.
label_rates
):
verify_label_lengths
(
self
.
sizes
,
sample_rate
,
label_path
,
label_rate
,
inds
,
tot
)
self
.
max_sample_size
=
(
max_sample_size
if
max_sample_size
is
not
None
else
sys
.
maxsize
)
self
.
pad_audio
=
pad_audio
self
.
normalize
=
normalize
logger
.
info
(
f
"pad_audio=
{
pad_audio
}
, random_crop=
{
random_crop
}
, "
f
"normalize=
{
normalize
}
, max_sample_size=
{
self
.
max_sample_size
}
"
)
def
get_audio
(
self
,
index
):
import
soundfile
as
sf
wav_path
=
os
.
path
.
join
(
self
.
audio_root
,
self
.
audio_names
[
index
])
wav
,
cur_sample_rate
=
sf
.
read
(
wav_path
)
wav
=
torch
.
from_numpy
(
wav
).
float
()
wav
=
self
.
postprocess
(
wav
,
cur_sample_rate
)
return
wav
def
get_label
(
self
,
index
,
label_idx
):
if
self
.
store_labels
:
label
=
self
.
label_list
[
label_idx
][
index
]
else
:
with
open
(
self
.
label_paths
[
label_idx
])
as
f
:
offset_s
,
offset_e
=
self
.
label_offsets_list
[
label_idx
][
index
]
f
.
seek
(
offset_s
)
label
=
f
.
read
(
offset_e
-
offset_s
)
if
self
.
label_processors
is
not
None
:
label
=
self
.
label_processors
[
label_idx
](
label
)
return
label
def
get_labels
(
self
,
index
):
return
[
self
.
get_label
(
index
,
i
)
for
i
in
range
(
self
.
num_labels
)]
def
__getitem__
(
self
,
index
):
wav
=
self
.
get_audio
(
index
)
labels
=
self
.
get_labels
(
index
)
return
{
"id"
:
index
,
"source"
:
wav
,
"label_list"
:
labels
}
def
__len__
(
self
):
return
len
(
self
.
sizes
)
def
crop_to_max_size
(
self
,
wav
,
target_size
):
size
=
len
(
wav
)
diff
=
size
-
target_size
if
diff
<=
0
:
return
wav
,
0
start
,
end
=
0
,
target_size
if
self
.
random_crop
:
start
=
np
.
random
.
randint
(
0
,
diff
+
1
)
end
=
size
-
diff
+
start
return
wav
[
start
:
end
],
start
def
collater
(
self
,
samples
):
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples
=
[
s
for
s
in
samples
if
s
[
"source"
]
is
not
None
]
if
len
(
samples
)
==
0
:
return
{}
audios
=
[
s
[
"source"
]
for
s
in
samples
]
audio_sizes
=
[
len
(
s
)
for
s
in
audios
]
if
self
.
pad_audio
:
audio_size
=
min
(
max
(
audio_sizes
),
self
.
max_sample_size
)
else
:
audio_size
=
min
(
min
(
audio_sizes
),
self
.
max_sample_size
)
collated_audios
,
padding_mask
,
audio_starts
=
self
.
collater_audio
(
audios
,
audio_size
)
targets_by_label
=
[
[
s
[
"label_list"
][
i
]
for
s
in
samples
]
for
i
in
range
(
self
.
num_labels
)
]
targets_list
,
lengths_list
,
ntokens_list
=
self
.
collater_label
(
targets_by_label
,
audio_size
,
audio_starts
)
net_input
=
{
"source"
:
collated_audios
,
"padding_mask"
:
padding_mask
}
batch
=
{
"id"
:
torch
.
LongTensor
([
s
[
"id"
]
for
s
in
samples
]),
"net_input"
:
net_input
,
}
if
self
.
single_target
:
batch
[
"target_lengths"
]
=
lengths_list
[
0
]
batch
[
"ntokens"
]
=
ntokens_list
[
0
]
batch
[
"target"
]
=
targets_list
[
0
]
else
:
batch
[
"target_lengths_list"
]
=
lengths_list
batch
[
"ntokens_list"
]
=
ntokens_list
batch
[
"target_list"
]
=
targets_list
return
batch
def
collater_audio
(
self
,
audios
,
audio_size
):
collated_audios
=
audios
[
0
].
new_zeros
(
len
(
audios
),
audio_size
)
padding_mask
=
(
torch
.
BoolTensor
(
collated_audios
.
shape
).
fill_
(
False
)
# if self.pad_audio else None
)
audio_starts
=
[
0
for
_
in
audios
]
for
i
,
audio
in
enumerate
(
audios
):
diff
=
len
(
audio
)
-
audio_size
if
diff
==
0
:
collated_audios
[
i
]
=
audio
elif
diff
<
0
:
assert
self
.
pad_audio
collated_audios
[
i
]
=
torch
.
cat
([
audio
,
audio
.
new_full
((
-
diff
,),
0.0
)])
padding_mask
[
i
,
diff
:]
=
True
else
:
collated_audios
[
i
],
audio_starts
[
i
]
=
self
.
crop_to_max_size
(
audio
,
audio_size
)
return
collated_audios
,
padding_mask
,
audio_starts
def
collater_frm_label
(
self
,
targets
,
audio_size
,
audio_starts
,
label_rate
,
pad
):
assert
label_rate
>
0
s2f
=
label_rate
/
self
.
sample_rate
frm_starts
=
[
int
(
round
(
s
*
s2f
))
for
s
in
audio_starts
]
frm_size
=
int
(
round
(
audio_size
*
s2f
))
if
not
self
.
pad_audio
:
rem_size
=
[
len
(
t
)
-
s
for
t
,
s
in
zip
(
targets
,
frm_starts
)]
frm_size
=
min
(
frm_size
,
*
rem_size
)
targets
=
[
t
[
s
:
s
+
frm_size
]
for
t
,
s
in
zip
(
targets
,
frm_starts
)]
logger
.
debug
(
f
"audio_starts=
{
audio_starts
}
"
)
logger
.
debug
(
f
"frame_starts=
{
frm_starts
}
"
)
logger
.
debug
(
f
"frame_size=
{
frm_size
}
"
)
lengths
=
torch
.
LongTensor
([
len
(
t
)
for
t
in
targets
])
ntokens
=
lengths
.
sum
().
item
()
targets
=
data_utils
.
collate_tokens
(
targets
,
pad_idx
=
pad
,
left_pad
=
False
)
return
targets
,
lengths
,
ntokens
def
collater_seq_label
(
self
,
targets
,
pad
):
lengths
=
torch
.
LongTensor
([
len
(
t
)
for
t
in
targets
])
ntokens
=
lengths
.
sum
().
item
()
targets
=
data_utils
.
collate_tokens
(
targets
,
pad_idx
=
pad
,
left_pad
=
False
)
return
targets
,
lengths
,
ntokens
def
collater_label
(
self
,
targets_by_label
,
audio_size
,
audio_starts
):
targets_list
,
lengths_list
,
ntokens_list
=
[],
[],
[]
itr
=
zip
(
targets_by_label
,
self
.
label_rates
,
self
.
pad_list
)
for
targets
,
label_rate
,
pad
in
itr
:
if
label_rate
==
-
1.0
:
targets
,
lengths
,
ntokens
=
self
.
collater_seq_label
(
targets
,
pad
)
else
:
targets
,
lengths
,
ntokens
=
self
.
collater_frm_label
(
targets
,
audio_size
,
audio_starts
,
label_rate
,
pad
)
targets_list
.
append
(
targets
)
lengths_list
.
append
(
lengths
)
ntokens_list
.
append
(
ntokens
)
return
targets_list
,
lengths_list
,
ntokens_list
def
num_tokens
(
self
,
index
):
return
self
.
size
(
index
)
def
size
(
self
,
index
):
if
self
.
pad_audio
:
return
self
.
sizes
[
index
]
return
min
(
self
.
sizes
[
index
],
self
.
max_sample_size
)
def
ordered_indices
(
self
):
if
self
.
shuffle
:
order
=
[
np
.
random
.
permutation
(
len
(
self
))]
else
:
order
=
[
np
.
arange
(
len
(
self
))]
order
.
append
(
self
.
sizes
)
return
np
.
lexsort
(
order
)[::
-
1
]
def
postprocess
(
self
,
wav
,
cur_sample_rate
):
if
wav
.
dim
()
==
2
:
wav
=
wav
.
mean
(
-
1
)
assert
wav
.
dim
()
==
1
,
wav
.
dim
()
if
cur_sample_rate
!=
self
.
sample_rate
:
raise
Exception
(
f
"sr
{
cur_sample_rate
}
!=
{
self
.
sample_rate
}
"
)
if
self
.
normalize
:
with
torch
.
no_grad
():
wav
=
F
.
layer_norm
(
wav
,
wav
.
shape
)
return
wav
Prev
1
…
5
6
7
8
9
10
11
12
13
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment