dcuai / dlexamples · Commit c0f05c10
Authored Nov 29, 2022 by hepj
Update transformer code
Parent: c056df78
Changes: 321. Showing 20 changed files with 2931 additions and 0 deletions (+2931, -0).
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction.py (+141, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction_adapters.py (+63, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_ranking.py (+120, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_to_speech_criterion.py (+310, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_ulm_criterion.py (+126, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/tacotron2_loss.py (+226, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/wav2vec_criterion.py (+230, -0)
PyTorch/NLP/new-Transformer/fairseq/data/__init__.py (+130, -0)
PyTorch/NLP/new-Transformer/fairseq/data/add_target_dataset.py (+83, -0)
PyTorch/NLP/new-Transformer/fairseq/data/append_token_dataset.py (+41, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/__init__.py (+0, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/audio_utils.py (+293, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/data_cfg.py (+299, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/__init__.py (+82, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/delta_deltas.py (+37, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/global_cmvn.py (+29, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/specaugment.py (+131, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/utterance_cmvn.py (+41, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/frm_text_to_speech_dataset.py (+205, -0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/hubert_dataset.py (+344, -0)
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class SentencePredictionConfig(FairseqDataclass):
    classification_head_name: str = field(
        default="sentence_classification_head",
        metadata={"help": "name of the classification head to use"},
    )
    regression_target: bool = field(
        default=False,
    )


@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionCriterion(FairseqCriterion):
    def __init__(self, cfg: SentencePredictionConfig, task):
        super().__init__(task)
        self.classification_head_name = cfg.classification_head_name
        self.regression_target = cfg.regression_target

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            task_loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            logits = logits.view(-1).float()
            targets = targets.float()
            task_loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {}
        loss = task_loss
        # mha & ffn regularization update
        if (
            hasattr(model.args, "mha_reg_scale_factor")
            and model.args.mha_reg_scale_factor != 0.0
        ):
            mha_reg_loss = model._get_adaptive_head_loss()
            loss += mha_reg_loss
            logging_output.update({"mha_reg_loss": mha_reg_loss})
        if (
            hasattr(model.args, "ffn_reg_scale_factor")
            and model.args.ffn_reg_scale_factor != 0.0
        ):
            ffn_reg_loss = model._get_adaptive_ffn_loss()
            loss += ffn_reg_loss
            logging_output.update({"ffn_reg_loss": ffn_reg_loss})

        logging_output.update(
            {
                "loss": loss.data,
                "ntokens": sample["ntokens"],
                "nsentences": sample_size,
                "sample_size": sample_size,
            }
        )

        if not self.regression_target:
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs)
        ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if mha_reg_loss_sum:
            metrics.log_scalar(
                "mha_reg_loss",
                mha_reg_loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )
        if ffn_reg_loss_sum:
            metrics.log_scalar(
                "ffn_reg_loss",
                ffn_reg_loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
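A quick arithmetic sketch (not part of the diff) of the normalization applied in reduce_metrics() above: the summed negative log-likelihood is divided by sample_size and by log(2), so the reported "loss" is a per-sample value in bits. The numbers below are made up.

import math

# Hypothetical aggregated values from a batch of 32 sentences.
loss_sum = 44.3          # summed NLL over the batch, in nats
sample_size = 32         # nsentences; also the denominator for the gradient
per_sentence_bits = loss_sum / sample_size / math.log(2)
print(round(per_sentence_bits, 3))   # 1.997 -- what metrics.log_scalar("loss", ...) reports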
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_prediction_adapters.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from fairseq.criterions import register_criterion
from fairseq.criterions.sentence_prediction import (
    SentencePredictionCriterion,
    SentencePredictionConfig,
)


@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionCriterionAdapters(SentencePredictionCriterion):
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        if not hasattr(sample, "lang_id"):
            # If no language ID is given, we fall back to English
            lang_id = ["en_XX"] * sample["nsentences"]
        else:
            lang_id = sample["lang_id"]

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
            lang_id=lang_id,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }

        if not self.regression_target:
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()

        return loss, sample_size, logging_output
PyTorch/NLP/new-Transformer/fairseq/criterions/sentence_ranking.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
    def __init__(self, task, ranking_head_name, save_predictions, num_classes):
        super().__init__(task)
        self.ranking_head_name = ranking_head_name
        if save_predictions is not None:
            self.prediction_h = open(save_predictions, "w")
        else:
            self.prediction_h = None
        self.num_classes = num_classes

    def __del__(self):
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--save-predictions', metavar='FILE',
                            help='file to save predictions to')
        parser.add_argument('--ranking-head-name',
                            default='sentence_classification_head',
                            help='name of the ranking head to use')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.ranking_head_name in model.classification_heads
        ), "model must provide sentence ranking head for --criterion=sentence_ranking"

        scores = []
        for idx in range(self.num_classes):
            score, _ = model(
                **sample["net_input{idx}".format(idx=idx + 1)],
                classification_head_name=self.ranking_head_name,
            )
            scores.append(score)

        logits = torch.cat(scores, dim=1)
        sample_size = logits.size(0)

        if "target" in sample:
            targets = model.get_targets(sample, [logits]).view(-1)
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            targets = None
            loss = torch.tensor(0.0, requires_grad=True)

        if self.prediction_h is not None:
            preds = logits.argmax(dim=1)
            for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
                if targets is not None:
                    label = targets[i].item()
                    print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
                else:
                    print("{}\t{}".format(id, pred), file=self.prediction_h)

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if targets is not None:
            logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_to_speech_criterion.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.ctc import CtcCriterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.criterions.tacotron2_loss import (
    Tacotron2Criterion,
    Tacotron2CriterionConfig,
)


class MultitaskCriterion:
    def __init__(self, multitask_tasks):
        self.multitask_criterion = {}
        self.multitask_loss_weight = {}
        for task_name, task_obj in multitask_tasks.items():
            if task_obj.args.decoder_type == "ctc":
                self.multitask_criterion[task_name] = CtcCriterion(
                    task_obj.args.criterion_cfg, task_obj
                )
            else:
                self.multitask_criterion[
                    task_name
                ] = LabelSmoothedCrossEntropyCriterion(
                    task_obj,
                    task_obj.args.criterion_cfg.sentence_avg,
                    label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
                )

    def set_multitask_loss_weight(self, task_name, weight=0.0):
        self.multitask_loss_weight[task_name] = weight

    def get_multitask_loss(self, model, sample, model_out):
        logging_output = {}
        loss = 0.0
        for task_name, task_criterion in self.multitask_criterion.items():
            layer_id = task_criterion.task.args.input_layer
            if isinstance(task_criterion, CtcCriterion):
                if task_criterion.task.args.input_from == "encoder":
                    non_padding_mask = ~model_out["encoder_padding_mask"][0]
                    input_lengths = non_padding_mask.long().sum(-1)
                    task_sample = {
                        "net_input": {
                            "src_tokens": model_out["encoder_states"][
                                layer_id
                            ],  # check batch idx
                            "src_lengths": input_lengths,
                        },
                        "id": sample["id"],
                    }
                else:
                    task_sample = {
                        "net_input": {
                            "src_tokens": model_out["inner_states"][layer_id],
                            "src_lengths": sample["target_lengths"],
                        },
                        "id": sample["id"],
                    }
            else:
                task_sample = {
                    "net_input": {
                        "src_tokens": sample["multitask"][task_name]["net_input"][
                            "prev_output_tokens"
                        ],
                        "encoder_out": {
                            "encoder_out": [model_out["encoder_states"][layer_id]],
                            "encoder_padding_mask": model_out["encoder_padding_mask"],
                        },
                    }
                }

            for key in ["target", "target_lengths", "ntokens"]:
                task_sample[key] = sample["multitask"][task_name][key]

            task_loss, task_sample_size, task_logging_output = task_criterion(
                model.multitask_decoders[task_name], task_sample
            )
            loss = loss + self.multitask_loss_weight[task_name] * task_loss
            task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name]
            logging_output[task_name] = task_logging_output
        return loss, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        for task_name in logging_outputs[0]["multitask"].keys():
            # different criterion may return different logging
            # currently only reduce on loss, the most common one
            # ideally the way that losses are reduced should also depend on the task type
            loss_sum = sum(
                log["multitask"][task_name].get("loss", 0) for log in logging_outputs
            )
            sample_size = sum(
                log["multitask"][task_name].get("sample_size", 0)
                for log in logging_outputs
            )

            metrics.log_scalar(
                f"multitask_{task_name}_loss",
                loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )

            loss_weight = logging_outputs[0]["multitask"][task_name].get(
                "loss_weight", 0
            )
            metrics.log_scalar(
                f"multitask_{task_name}_loss_weight",
                loss_weight,
                weight=0,
                priority=250,
            )


@register_criterion(
    "speech_to_unit", dataclass=LabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnitMultitaskTaskCriterion(
    LabelSmoothedCrossEntropyCriterion, MultitaskCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        MultitaskCriterion.__init__(self, task.multitask_tasks)

    def forward(self, model, sample, reduce=True):
        net_output, extra = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            tgt_speaker=sample["net_input"]["tgt_speaker"],
            return_all_hiddens=True,
        )
        loss, nll_loss = self.compute_loss(model, [net_output], sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, [net_output], sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)

        # inference metrics
        if "targ_frames" in logging_outputs[0]:
            n = sum(log.get("norm_frames", 0) for log in logging_outputs)
            for key, new_key in [
                ("mcd_loss", "mcd_loss"),
                ("pred_frames", "pred_ratio"),
                ("nins", "ins_rate"),
                ("ndel", "del_rate"),
            ]:
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(new_key, val / n, n, round=3)

        if "multitask" not in logging_outputs[0]:
            return

        MultitaskCriterion.reduce_metrics(logging_outputs)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        use_guided_attention_loss,
        guided_attention_loss_sigma,
        bce_pos_weight,
        ctc_weight,
    ):
        super().__init__(
            task,
            sentence_avg,
            use_guided_attention_loss,
            guided_attention_loss_sigma,
            bce_pos_weight,
            ctc_weight,
        )
        MultitaskCriterion.__init__(self, task.multitask_tasks)

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()

        feat_out, eos_out, extra = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            tgt_speaker=sample["net_input"]["tgt_speaker"],
            target_lengths=sample["target_lengths"],
            return_all_hiddens=True,
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"],
            feat_out,
            eos_out,
            feat_tgt,
            eos_tgt,
            sample["target_lengths"],
            reduction,
        )
        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(
                extra["attn"],
                sample["net_input"]["src_lengths"],
                sample["target_lengths"],
                reduction,
            )
        loss = (
            l1_loss + mse_loss + eos_loss + attn_loss
        )  # do not include ctc loss as there's no text target

        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
        }

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)

        # inference metrics
        if "targ_frames" in logging_outputs[0]:
            n = sum(log.get("norm_frames", 0) for log in logging_outputs)
            for key, new_key in [
                ("mcd_loss", "mcd_loss"),
                ("pred_frames", "pred_ratio"),
                ("nins", "ins_rate"),
                ("ndel", "del_rate"),
            ]:
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(new_key, val / n, n, round=3)

        if "multitask" not in logging_outputs[0]:
            return

        MultitaskCriterion.reduce_metrics(logging_outputs)
PyTorch/NLP/new-Transformer/fairseq/criterions/speech_ulm_criterion.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from dataclasses import dataclass, field
import torch.nn.functional as F

from fairseq import metrics
from fairseq.tasks import FairseqTask
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class SpeechUnitLmCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")
    loss_weights: str = field(
        default="1.;0.0;0.0",
        metadata={
            "help": "Weights of the losses that correspond to token, duration, and F0 streams"
        },
    )
    discrete_duration: bool = II("task.discrete_duration")
    discrete_f0: bool = II("task.discrete_f0")


def mae_loss(pred, targ, mask, reduce=True):
    if pred.ndim == 3:
        pred = pred.squeeze(2)
    else:
        assert pred.ndim == 2
    loss = (pred.float() - targ.float()).abs() * (~mask).float()
    loss = loss.sum() if reduce else loss.view(-1)
    return loss


def nll_loss(pred, targ, mask, reduce=True):
    lprob = F.log_softmax(pred, dim=-1)
    loss = F.nll_loss(lprob.view(-1, lprob.size(-1)), targ.view(-1), reduction="none")
    loss = loss * (~mask).float().view(-1)
    loss = loss.sum() if reduce else loss.view(-1)
    return loss


@register_criterion("speech_unit_lm_criterion", dataclass=SpeechUnitLmCriterionConfig)
class SpeechUnitLmCriterion(FairseqCriterion):
    def __init__(self, cfg: SpeechUnitLmCriterionConfig, task: FairseqTask):
        super().__init__(task)
        self.sentence_avg = cfg.sentence_avg
        self.weights = torch.tensor([float(w) for w in cfg.loss_weights.split(";")])
        assert self.weights.size(0) == 3
        assert (self.weights >= 0.0).all()

        self.dur_loss_fn = nll_loss if cfg.discrete_duration else mae_loss
        self.f0_loss_fn = nll_loss if cfg.discrete_f0 else mae_loss

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])

        token_loss = nll_loss(
            net_output["token"], sample["target"], sample["mask"], reduce
        )
        dur_loss = self.dur_loss_fn(
            net_output["duration"],
            sample["dur_target"],
            sample["dur_mask"],
            reduce,
        )
        f0_loss = self.f0_loss_fn(
            net_output["f0"],
            sample["f0_target"],
            sample["f0_mask"],
            reduce,
        )
        loss = self.weights.to(token_loss.device) * torch.stack(
            [token_loss, dur_loss, f0_loss], dim=-1
        )
        loss = loss.sum() if reduce else loss.sum(-1)

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.detach().sum().item(),
            "token_loss": token_loss.detach().sum().item(),
            "dur_loss": dur_loss.detach().sum().item(),
            "f0_loss": f0_loss.detach().sum().item(),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        token_loss_sum = sum(log.get("token_loss", 0) for log in logging_outputs)
        dur_loss_sum = sum(log.get("dur_loss", 0) for log in logging_outputs)
        f0_loss_sum = sum(log.get("f0_loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
        metrics.log_scalar(
            "token_loss", token_loss_sum / sample_size, sample_size, round=3
        )
        metrics.log_scalar(
            "dur_loss", dur_loss_sum / sample_size, sample_size, round=3
        )
        metrics.log_scalar(
            "f0_loss", f0_loss_sum / sample_size, sample_size, round=3
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return True
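A small self-contained sketch (not part of the diff) of how the loss_weights config string above ("1.;0.0;0.0") becomes the per-stream weighting applied in forward(): the three summed stream losses are stacked and multiplied element-wise by the weights before the final sum. The loss values below are made up.

import torch

# Parse the default config string into token / duration / F0 weights.
weights = torch.tensor([float(w) for w in "1.;0.0;0.0".split(";")])
token_loss, dur_loss, f0_loss = torch.tensor(3.2), torch.tensor(1.1), torch.tensor(0.7)
loss = (weights * torch.stack([token_loss, dur_loss, f0_loss], dim=-1)).sum()
print(loss.item())   # 3.2 -- duration and F0 streams are zero-weighted with the default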
PyTorch/NLP/new-Transformer/fairseq/criterions/tacotron2_loss.py (new file, mode 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from omegaconf import II

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import lengths_to_mask
from fairseq.dataclass import FairseqDataclass

logger = logging.getLogger(__name__)


@dataclass
class Tacotron2CriterionConfig(FairseqDataclass):
    bce_pos_weight: float = field(
        default=1.0,
        metadata={"help": "weight of positive examples for BCE loss"},
    )
    use_guided_attention_loss: bool = field(
        default=False,
        metadata={"help": "use guided attention loss"},
    )
    guided_attention_loss_sigma: float = field(
        default=0.4,
        metadata={"help": "weight of positive examples for BCE loss"},
    )
    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
    sentence_avg: bool = II("optimization.sentence_avg")


class GuidedAttentionLoss(torch.nn.Module):
    """
    Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
    Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
    """

    def __init__(self, sigma):
        super().__init__()
        self.sigma = sigma

    @staticmethod
    @lru_cache(maxsize=8)
    def _get_weight(s_len, t_len, sigma):
        grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
        grid_x = grid_x.to(s_len.device)
        grid_y = grid_y.to(s_len.device)
        w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
        return 1.0 - torch.exp(-w / (2 * (sigma**2)))

    def _get_weights(self, src_lens, tgt_lens):
        bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
        weights = torch.zeros((bsz, max_t_len, max_s_len))
        for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
            weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
        return weights

    @staticmethod
    def _get_masks(src_lens, tgt_lens):
        in_masks = lengths_to_mask(src_lens)
        out_masks = lengths_to_mask(tgt_lens)
        return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)

    def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
        weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
        masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
        loss = (weights * attn.transpose(1, 2)).masked_select(masks)
        loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
        return loss


@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        use_guided_attention_loss,
        guided_attention_loss_sigma,
        bce_pos_weight,
        ctc_weight,
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.bce_pos_weight = bce_pos_weight
        self.guided_attn = None
        if use_guided_attention_loss:
            self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
        self.ctc_weight = ctc_weight

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]

        feat_out, eos_out, extra = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"],
            feat_out,
            eos_out,
            feat_tgt,
            eos_tgt,
            tgt_lens,
            reduction,
        )
        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            net_output = (feat_out, eos_out, extra)
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = (
                F.ctc_loss(
                    lprobs,
                    src_tokens_flat,
                    tgt_lens,
                    src_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )
        loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss

        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    def compute_loss(
        self,
        feat_out,
        feat_out_post,
        eos_out,
        feat_tgt,
        eos_tgt,
        tgt_lens,
        reduction="mean",
    ):
        mask = lengths_to_mask(tgt_lens)
        _eos_out = eos_out[mask].squeeze()
        _eos_tgt = eos_tgt[mask]
        _feat_tgt = feat_tgt[mask]
        _feat_out = feat_out[mask]
        _feat_out_post = feat_out_post[mask]

        l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
            _feat_out_post, _feat_tgt, reduction=reduction
        )
        mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
            _feat_out_post, _feat_tgt, reduction=reduction
        )
        eos_loss = F.binary_cross_entropy_with_logits(
            _eos_out,
            _eos_tgt,
            pos_weight=torch.tensor(self.bce_pos_weight),
            reduction=reduction,
        )
        return l1_loss, mse_loss, eos_loss

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False
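A standalone sketch (not part of the diff) that reproduces the penalty surface built by GuidedAttentionLoss._get_weight above for toy lengths: it is near zero where attention follows the diagonal alignment path and grows as attention strays from it (sigma=0.4 is the config default).

import torch

def guided_attention_weight(s_len, t_len, sigma=0.4):
    # Same formula as _get_weight above, with plain ints instead of length tensors.
    grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
    w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
    return 1.0 - torch.exp(-w / (2 * sigma ** 2))

w = guided_attention_weight(s_len=5, t_len=4)
print(w.shape)       # torch.Size([4, 5])
print(w.diagonal())  # small values: near-diagonal attention is barely penalized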
PyTorch/NLP/new-Transformer/fairseq/criterions/wav2vec_criterion.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.utils import is_xla_tensor


@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
    infonce: bool = field(
        default=False,
        metadata={
            "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
        },
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
    def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.infonce = infonce
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        logits = model.get_logits(net_output).float()
        target = model.get_targets(sample, net_output)
        self.xla = is_xla_tensor(logits)

        # XXX: handle weights on xla.
        weights = None
        if hasattr(model, "get_target_weights") and not self.infonce:
            weights = model.get_target_weights(target, net_output)
            if torch.is_tensor(weights):
                weights = weights.float()

        losses = []

        reduction = "none" if ((not reduce) or self.xla) else "sum"
        if self.infonce:
            loss = F.cross_entropy(logits, target, reduction=reduction)
        else:
            loss = F.binary_cross_entropy_with_logits(
                logits, target.float(), weights, reduction=reduction
            )

        if self.xla:
            # tpu-comment: since dynamic shapes lead to recompilations on xla,
            # we don't shrink tensors using mask_indices.
            # Instead, we use mask indices to adjust loss.
            mi = (
                sample["net_input"]["mask_indices"]
                .transpose(0, 1)  # logits are transposed in `model.get_logits`
                .reshape(logits.size(0))
            )
            loss = (loss * mi).sum() if reduce else (loss * mi)

        if "sample_size" in sample:
            sample_size = sample["sample_size"]
        elif "mask_indices" in sample["net_input"]:
            sample_size = sample["net_input"]["mask_indices"].sum()
        else:
            sample_size = target.numel() if self.infonce else target.long().sum().item()
        losses.append(loss.detach().clone())

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, coef in zip(extra_losses, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    losses.append(p)

        logging_output = {
            "loss": loss.item() if (reduce and not self.xla) else loss.detach(),
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        for lk in self.log_keys:
            # Only store "logits" and "target" for computing MAP and MAUC
            # during validation
            if lk == "logits":
                if not self.training:
                    logging_output["logits"] = logits.cpu().numpy()
            elif lk == "target":
                if not self.training:
                    # If the targets have been mixed with the predictions of
                    # teacher models, find the original targets
                    if hasattr(model, "get_original_targets"):
                        original_target = model.get_original_targets(sample, net_output)
                    else:
                        original_target = target
                    logging_output["target"] = original_target.cpu().numpy()
            elif lk in net_output:
                value = net_output[lk]
                if not is_xla_tensor(value):
                    value = float(value)
                logging_output[lk] = value

        if len(losses) > 1:
            for i, l in enumerate(losses):
                logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()

        if self.infonce:
            with torch.no_grad():
                if logits.numel() == 0:
                    corr = 0
                    count = 0
                else:
                    assert logits.dim() > 1, logits.shape
                    max = logits.argmax(-1) == 0
                    min = logits.argmin(-1) == 0
                    if is_xla_tensor(logits):
                        max, min = max * mi, min * mi
                        both = max & min
                        corr = max.long().sum() - both.long().sum()
                        count = mi.sum()
                    else:
                        both = max & min
                        corr = max.long().sum().item() - both.long().sum().item()
                        count = float(max.numel())

            logging_output["correct"] = corr
            logging_output["count"] = count

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5
                )
                if meters["_total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {
            "loss",
            "ntokens",
            "nsentences",
            "sample_size",
            "correct",
            "count",
        }

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs)
                if k.startswith("loss"):
                    metrics.log_scalar(
                        k, val / (sample_size or 1) / math.log(2), sample_size, round=3
                    )
                else:
                    metrics.log_scalar(k, val / len(logging_outputs), round=3)

    # FIXME: revert when gather based xla reduction is implemented
    # @staticmethod
    # def logging_outputs_can_be_summed() -> bool:
    def logging_outputs_can_be_summed(self) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        # XXX: Gather based reduction not implemented for xla yet.
        # So we fall to sum based reduction for xla.
        return self.xla
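A self-contained sketch (not part of the diff) of the InfoNCE accuracy bookkeeping in forward() above: the code treats index 0 of the candidate dimension as the correct class, so a prediction counts as correct when argmax(-1) == 0, minus degenerate rows where argmax and argmin coincide. The logits below are made up.

import torch

logits = torch.tensor([[2.0, 0.1, -1.0],    # index 0 scores highest -> counted correct
                       [0.3, 1.5, 0.2]])    # a distractor wins -> not correct
max_ = logits.argmax(-1) == 0
min_ = logits.argmin(-1) == 0
both = max_ & min_                           # rows where index 0 is both max and min (degenerate)
corr = max_.long().sum().item() - both.long().sum().item()
count = float(max_.numel())
print(corr, "/", int(count))                 # 1 / 2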
PyTorch/NLP/new-Transformer/fairseq/data/__init__.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .dictionary import Dictionary, TruncatedDictionary

from .fairseq_dataset import FairseqDataset, FairseqIterableDataset

from .base_wrapper_dataset import BaseWrapperDataset

from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset

from .iterators import (
    CountingIterator,
    EpochBatchIterator,
    GroupedIterator,
    ShardedIterator,
)

__all__ = [
    "AddTargetDataset",
    "AppendTokenDataset",
    "BacktranslationDataset",
    "BaseWrapperDataset",
    "BinarizedAudioDataset",
    "BucketPadLengthDataset",
    "ColorizeDataset",
    "ConcatDataset",
    "ConcatSentencesDataset",
    "CountingIterator",
    "DenoisingDataset",
    "Dictionary",
    "EncodedFastaDataset",
    "EpochBatchIterator",
    "FairseqDataset",
    "FairseqIterableDataset",
    "FastaDataset",
    "FileAudioDataset",
    "GroupedIterator",
    "HubertDataset",
    "IdDataset",
    "IndexedCachedDataset",
    "IndexedDataset",
    "IndexedRawTextDataset",
    "LanguagePairDataset",
    "LeftPadDataset",
    "ListDataset",
    "LMContextWindowDataset",
    "LRUCacheDataset",
    "MaskTokensDataset",
    "MMapIndexedDataset",
    "MonolingualDataset",
    "MultiCorpusSampledDataset",
    "NestedDictionaryDataset",
    "NoisingDataset",
    "NumelDataset",
    "NumSamplesDataset",
    "OffsetTokensDataset",
    "PadDataset",
    "PrependDataset",
    "PrependTokenDataset",
    "RandomCropDataset",
    "RawLabelDataset",
    "ResamplingDataset",
    "ReplaceDataset",
    "RightPadDataset",
    "RollDataset",
    "RoundRobinZipDatasets",
    "SampledMultiDataset",
    "SampledMultiEpochDataset",
    "ShardedIterator",
    "SortDataset",
    "StripTokenDataset",
    "SubsampleDataset",
    "TokenBlockDataset",
    "TransformEosDataset",
    "TransformEosLangPairDataset",
    "TransformEosConcatLangPairDataset",
    "TruncateDataset",
    "TruncatedDictionary",
]
PyTorch/NLP/new-Transformer/fairseq/data/add_target_dataset.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


class AddTargetDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.label_len_fn = label_len_fn
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
        lbl = self.labels[index]
        lbl = self.text_compressor.decompress(lbl)
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        indices = set(collated["id"].tolist())

        target = [s["label"] for s in samples if s["id"] in indices]

        if self.add_to_input:
            eos = torch.LongTensor([self.eos])
            prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
            target = [torch.cat([t, eos], axis=-1) for t in target]
            collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            if getattr(collated["net_input"], "prev_output_tokens", None):
                collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
                    collated["net_input"]["prev_output_tokens"],
                    pad_idx=self.pad,
                    left_pad=False,
                )
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored
PyTorch/NLP/new-Transformer/fairseq/data/append_token_dataset.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class AppendTokenDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item, item.new([self.token])])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n
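A tiny standalone sketch (not part of the diff) of what AppendTokenDataset.__getitem__ above does to a single item when a token is given: the token is concatenated to the end, and sizes/num_tokens grow by one. The token index below is made up.

import torch

eos_index = 2                            # hypothetical end-of-sentence index
item = torch.LongTensor([4, 5, 6])       # a toy tokenized sentence
appended = torch.cat([item, item.new([eos_index])])
print(appended.tolist())                 # [4, 5, 6, 2]; size(index) would now report one more token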
PyTorch/NLP/Transformer/scripts/__init__.py → PyTorch/NLP/new-Transformer/fairseq/data/audio/__init__.py (file moved)
PyTorch/NLP/new-Transformer/fairseq/data/audio/audio_utils.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
mmap
from
pathlib
import
Path
from
typing
import
BinaryIO
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
SF_AUDIO_FILE_EXTENSIONS
=
{
".wav"
,
".flac"
,
".ogg"
}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS
=
{
".npy"
,
".wav"
,
".flac"
,
".ogg"
}
def
convert_waveform
(
waveform
:
Union
[
np
.
ndarray
,
torch
.
Tensor
],
sample_rate
:
int
,
normalize_volume
:
bool
=
False
,
to_mono
:
bool
=
False
,
to_sample_rate
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
Union
[
np
.
ndarray
,
torch
.
Tensor
],
int
]:
"""convert a waveform:
- to a target sample rate
- from multi-channel to mono channel
- volume normalization
Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length)
sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate
Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate
"""
try
:
import
torchaudio.sox_effects
as
ta_sox
except
ImportError
:
raise
ImportError
(
"Please install torchaudio: pip install torchaudio"
)
effects
=
[]
if
normalize_volume
:
effects
.
append
([
"gain"
,
"-n"
])
if
to_sample_rate
is
not
None
and
to_sample_rate
!=
sample_rate
:
effects
.
append
([
"rate"
,
f
"
{
to_sample_rate
}
"
])
if
to_mono
and
waveform
.
shape
[
0
]
>
1
:
effects
.
append
([
"channels"
,
"1"
])
if
len
(
effects
)
>
0
:
is_np_input
=
isinstance
(
waveform
,
np
.
ndarray
)
_waveform
=
torch
.
from_numpy
(
waveform
)
if
is_np_input
else
waveform
converted
,
converted_sample_rate
=
ta_sox
.
apply_effects_tensor
(
_waveform
,
sample_rate
,
effects
)
if
is_np_input
:
converted
=
converted
.
numpy
()
return
converted
,
converted_sample_rate
return
waveform
,
sample_rate
def get_waveform(
    path_or_fp: Union[str, BinaryIO],
    normalization: bool = True,
    mono: bool = True,
    frames: int = -1,
    start: int = 0,
    always_2d: bool = True,
    output_sample_rate: Optional[int] = None,
    normalize_volume: bool = False,
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): normalize values to [-1, 1] (Default: True)
        mono (bool): convert multi-channel audio to mono-channel one
        frames (int): the number of frames to read. (-1 for reading all)
        start (int): Where to start reading. A negative value counts from the end.
        always_2d (bool): always return 2D array even for mono-channel audios
        output_sample_rate (Optional[int]): output sample rate
        normalize_volume (bool): normalize volume
    Returns:
        waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
        sample_rate (float): sample rate
    """
    if isinstance(path_or_fp, str):
        ext = Path(path_or_fp).suffix
        if ext not in SF_AUDIO_FILE_EXTENSIONS:
            raise ValueError(f"Unsupported audio format: {ext}")

    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile: pip install soundfile")

    waveform, sample_rate = sf.read(
        path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
    )
    waveform = waveform.T  # T x C -> C x T
    waveform, sample_rate = convert_waveform(
        waveform,
        sample_rate,
        normalize_volume=normalize_volume,
        to_mono=mono,
        to_sample_rate=output_sample_rate,
    )

    if not normalization:
        waveform *= 2**15  # denormalized to 16-bit signed integers
    if not always_2d:
        waveform = waveform.squeeze(axis=0)
    return waveform, sample_rate


def _get_kaldi_fbank(
    waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
    """Get mel-filter bank features via PyKaldi."""
    try:
        from kaldi.feat.fbank import Fbank, FbankOptions
        from kaldi.feat.mel import MelBanksOptions
        from kaldi.feat.window import FrameExtractionOptions
        from kaldi.matrix import Vector

        mel_opts = MelBanksOptions()
        mel_opts.num_bins = n_bins
        frame_opts = FrameExtractionOptions()
        frame_opts.samp_freq = sample_rate
        opts = FbankOptions()
        opts.mel_opts = mel_opts
        opts.frame_opts = frame_opts
        fbank = Fbank(opts=opts)
        features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
        return features
    except ImportError:
        return None


def _get_torchaudio_fbank(
    waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
    """Get mel-filter bank features via TorchAudio."""
    try:
        import torchaudio.compliance.kaldi as ta_kaldi

        waveform = torch.from_numpy(waveform)
        features = ta_kaldi.fbank(
            waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
        )
        return features.numpy()
    except ImportError:
        return None


def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
    (faster CPP implementation) to TorchAudio (Python implementation). Note that
    Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
    waveform should not be normalized."""
    waveform, sample_rate = get_waveform(path_or_fp, normalization=False)

    features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
    if features is None:
        features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable "
            "online filterbank feature extraction"
        )

    return features


def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78


def is_sf_audio_data(data: bytes) -> bool:
    is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
    is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
    is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
    return is_wav or is_flac or is_ogg


def mmap_read(path: str, offset: int, length: int) -> bytes:
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
            data = mmap_o[offset : offset + length]
    return data


def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
    return mmap_read(zip_path, offset, length)


def parse_path(path: str) -> Tuple[str, List[int]]:
    """Parse data path which is either a path to
    1. a .npy/.wav/.flac/.ogg file
    2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"

    Args:
        path (str): the data path to parse

    Returns:
        file_path (str): the file path
        slice_ptr (list of int): empty in case 1;
            byte offset and length for the slice in case 2
    """
    if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        _path, slice_ptr = path, []
    else:
        _path, *slice_ptr = path.split(":")
        if not Path(_path).is_file():
            raise FileNotFoundError(f"File not found: {_path}")
    assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
    slice_ptr = [int(i) for i in slice_ptr]
    return _path, slice_ptr
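The `parse_path` docstring above distinguishes plain audio files from byte slices inside a stored ZIP. A minimal sketch of how the two cases are typically consumed downstream; the paths, offset and length below are made-up illustrations, not values taken from this commit:

from fairseq.data.audio.audio_utils import (
    get_fbank,
    get_waveform,
    parse_path,
    read_from_stored_zip,
)

# Hypothetical inputs; the offset/length pair is illustrative only.
for p in ("/data/audio/utt1.flac",
          "/data/audio/train_audio.zip:1024:20480"):
    file_path, slice_ptr = parse_path(p)
    if len(slice_ptr) == 0:
        # Case 1: a plain .npy/.wav/.flac/.ogg file
        waveform, sr = get_waveform(file_path, mono=True)
        fbank = get_fbank(file_path)  # 80-bin filterbank features (T x 80)
    else:
        # Case 2: a byte slice "[zip_path]:[offset]:[length]" in a stored ZIP
        offset, length = slice_ptr
        data = read_from_stored_zip(file_path, offset, length)
        # `data` holds the raw bytes of one utterance; is_sf_audio_data(data)
        # / is_npy_data(data) tell how those bytes should be decoded.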
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
    padding = n_fft - win_length
    assert padding >= 0
    return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))


def get_fourier_basis(n_fft: int) -> torch.Tensor:
    basis = np.fft.fft(np.eye(n_fft))
    basis = np.vstack(
        [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
    )
    return torch.from_numpy(basis).float()


def get_mel_filters(
    sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
    try:
        import librosa
    except ImportError:
        raise ImportError("Please install librosa: pip install librosa")
    basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
    return torch.from_numpy(basis).float()


class TTSSpectrogram(torch.nn.Module):
    def __init__(
        self,
        n_fft: int,
        win_length: int,
        hop_length: int,
        window_fn: callable = torch.hann_window,
        return_phase: bool = False,
    ) -> None:
        super(TTSSpectrogram, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.return_phase = return_phase

        basis = get_fourier_basis(n_fft).unsqueeze(1)
        basis *= get_window(window_fn, n_fft, win_length)
        self.register_buffer("basis", basis)

    def forward(
        self, waveform: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        padding = (self.n_fft // 2, self.n_fft // 2)
        x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
        x = F.conv1d(x, self.basis, stride=self.hop_length)
        real_part = x[:, : self.n_fft // 2 + 1, :]
        imag_part = x[:, self.n_fft // 2 + 1 :, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if self.return_phase:
            phase = torch.atan2(imag_part, real_part)
            return magnitude, phase
        return magnitude


class TTSMelScale(torch.nn.Module):
    def __init__(
        self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
    ) -> None:
        super(TTSMelScale, self).__init__()
        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
        self.register_buffer("basis", basis)

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.basis, specgram)
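TTSSpectrogram and TTSMelScale are meant to be chained: the first computes an STFT magnitude by convolving the waveform with a fixed Fourier basis, the second projects it onto mel filters (which requires librosa). A hedged sketch with illustrative hyper-parameters, not values prescribed by this commit:

import torch
from fairseq.data.audio.audio_utils import TTSMelScale, TTSSpectrogram

# Illustrative TTS-style settings (22.05 kHz audio, 80 mel bins).
n_fft, win_length, hop_length = 1024, 1024, 256
spec = TTSSpectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length)
mel = TTSMelScale(
    n_mels=80, sample_rate=22050, f_min=0.0, f_max=8000.0, n_stft=n_fft // 2 + 1
)

waveform = torch.randn(1, 22050)   # (batch, samples) dummy audio
magnitude = spec(waveform)         # (batch, n_fft // 2 + 1, frames)
mel_spec = mel(magnitude)          # (batch, n_mels, frames)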
PyTorch/NLP/new-Transformer/fairseq/data/audio/data_cfg.py
0 → 100644
View file @ c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary


def get_config_from_yaml(yaml_path: Path):
    try:
        import yaml
    except ImportError:
        print("Please install PyYAML: pip install PyYAML")
    config = {}
    if yaml_path.is_file():
        try:
            with open(yaml_path) as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except Exception as e:
            raise Exception(
                f"Failed to load config from {yaml_path.as_posix()}: {e}"
            )
    else:
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    return config


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_feature_transforms(self, split, is_train):
        """Split-specific feature transforms. Allowing train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching."""
        from copy import deepcopy

        cfg = deepcopy(self.config)
        _cur = cfg.get("transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        cfg["transforms"] = cur
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})


class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        _cur = self.config.get("transforms", {})
        cur = _cur.get("_train", [])

        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]


class SingleTaskConfig(object):
    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight
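For orientation, a minimal sketch of the kind of YAML that S2TDataConfig wraps; every key mirrors one of the `self.config.get(...)` calls above, and all values (including the /tmp path) are illustrative assumptions rather than settings shipped with this commit:

from pathlib import Path

import yaml

from fairseq.data.audio.data_cfg import S2TDataConfig

# Illustrative config; each key corresponds to a property defined above.
cfg_dict = {
    "vocab_filename": "dict.txt",
    "input_feat_per_channel": 80,
    "input_channels": 1,
    "sample_rate": 16000,
    "shuffle": True,
    "transforms": {
        "_train": ["utterance_cmvn", "specaugment"],
        "_eval": ["utterance_cmvn"],
    },
    "specaugment": {
        "freq_mask_N": 2, "freq_mask_F": 27,
        "time_mask_N": 2, "time_mask_T": 100, "time_mask_p": 1.0,
    },
}
yaml_path = Path("/tmp/config.yaml")  # hypothetical location
yaml_path.write_text(yaml.dump(cfg_dict))

data_cfg = S2TDataConfig(yaml_path)
print(data_cfg.input_feat_per_channel)  # -> 80
print(data_cfg.get_feature_transforms("train", is_train=True)["transforms"])
# -> ['utterance_cmvn', 'specaugment']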
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/__init__.py
0 → 100644
View file @ c0f05c10
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, Optional


class AudioFeatureTransform(ABC):
    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        pass


AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()


def register_audio_feature_transform(name):
    def register_audio_feature_transform_cls(cls):
        if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
            raise ValueError(f"Cannot register duplicate transform ({name})")
        if not issubclass(cls, AudioFeatureTransform):
            raise ValueError(
                f"Transform ({name}: {cls.__name__}) must extend "
                "AudioFeatureTransform"
            )
        if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
            raise ValueError(
                f"Cannot register audio feature transform with duplicate "
                f"class name ({cls.__name__})"
            )
        AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_audio_feature_transform_cls


def get_audio_feature_transform(name):
    return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]


transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
    path = os.path.join(transforms_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        name = file[: file.find(".py")] if file.endswith(".py") else file
        importlib.import_module("fairseq.data.audio.feature_transforms." + name)


class CompositeAudioFeatureTransform(AudioFeatureTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        _transforms = _config.get("transforms")
        if _transforms is None:
            return None
        transforms = [
            get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
            for _t in _transforms
        ]
        return CompositeAudioFeatureTransform(transforms)

    def __init__(self, transforms):
        self.transforms = [t for t in transforms if t is not None]

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def __repr__(self):
        format_string = (
            [self.__class__.__name__ + "("]
            + [f"    {t.__repr__()}" for t in self.transforms]
            + [")"]
        )
        return "\n".join(format_string)
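A minimal sketch of how this registry is typically used: CompositeAudioFeatureTransform.from_config_dict builds each named transform from its sub-config and chains them in order. The transform names rely on the utterance_cmvn and specaugment transforms registered later in this commit, and the SpecAugment settings and dummy feature shape are illustrative only:

import numpy as np

from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform

# Illustrative transform config; keys are registry names, sub-dicts their settings.
config = {
    "transforms": ["utterance_cmvn", "specaugment"],
    "specaugment": {
        "freq_mask_N": 2, "freq_mask_F": 27,
        "time_mask_N": 2, "time_mask_T": 100, "time_mask_p": 1.0,
    },
}
composite = CompositeAudioFeatureTransform.from_config_dict(config)

features = np.random.randn(500, 80).astype(np.float32)  # T x F dummy fbank
augmented = composite(features)                          # same T x F layout
print(composite)  # lists the chained transforms via __repr__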
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/delta_deltas.py
0 → 100644
View file @ c0f05c10
import numpy as np
import torch
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("delta_deltas")
class DeltaDeltas(AudioFeatureTransform):
    """Expand delta-deltas features from spectrum."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return DeltaDeltas(_config.get("win_length", 5))

    def __init__(self, win_length=5):
        self.win_length = win_length

    def __repr__(self):
        return self.__class__.__name__

    def __call__(self, spectrogram):
        from torchaudio.functional import compute_deltas

        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
        # spectrogram is T x F, while compute_deltas takes (…, F, T)
        spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
        delta = compute_deltas(spectrogram)
        delta_delta = compute_deltas(delta)

        out_feat = np.concatenate(
            [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
        )
        out_feat = np.transpose(out_feat)
        return out_feat
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/global_cmvn.py
0 → 100644
View file @ c0f05c10
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
    """Global CMVN (cepstral mean and variance normalization). The global mean
    and variance need to be pre-computed and stored in NumPy format (.npz)."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return GlobalCMVN(_config.get("stats_npz_path"))

    def __init__(self, stats_npz_path):
        self.stats_npz_path = stats_npz_path
        stats = np.load(stats_npz_path)
        self.mean, self.std = stats["mean"], stats["std"]

    def __repr__(self):
        return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'

    def __call__(self, x):
        x = np.subtract(x, self.mean)
        x = np.divide(x, self.std)
        return x
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/specaugment.py
0 → 100644
View file @ c0f05c10
import math
import numbers
from typing import Optional

import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return SpecAugmentTransform(
            _config.get("time_warp_W", 0),
            _config.get("freq_mask_N", 0),
            _config.get("freq_mask_F", 0),
            _config.get("time_mask_N", 0),
            _config.get("time_mask_T", 0),
            _config.get("time_mask_p", 0.0),
            _config.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        # Sanity checks
        assert mask_value is None or isinstance(
            mask_value, numbers.Number
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.time_warp_w = time_warp_w
        self.freq_mask_n = freq_mask_n
        self.freq_mask_f = freq_mask_f
        self.time_mask_n = time_mask_n
        self.time_mask_t = time_mask_t
        self.time_mask_p = time_mask_p
        self.mask_value = mask_value

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"time_warp_w={self.time_warp_w}",
                    f"freq_mask_n={self.freq_mask_n}",
                    f"freq_mask_f={self.freq_mask_f}",
                    f"time_mask_n={self.time_mask_n}",
                    f"time_mask_t={self.time_mask_t}",
                    f"time_mask_p={self.time_mask_p}",
                ]
            )
            + ")"
        )

    def __call__(self, spectrogram):
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        distorted = spectrogram.copy()  # make a copy of input spectrogram.
        num_frames = spectrogram.shape[0]  # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'miu' in the paper.
        mask_value = self.mask_value

        if mask_value is None:  # if no value was specified, use local mean.
            mask_value = spectrogram.mean()

        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        if self.time_warp_w > 0:
            if 2 * self.time_warp_w < num_frames:
                import cv2

                w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
                w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
                upper, lower = distorted[:w0, :], distorted[w0:, :]
                upper = cv2.resize(
                    upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
                )
                lower = cv2.resize(
                    lower,
                    dsize=(num_freqs, num_frames - w0 - w),
                    interpolation=cv2.INTER_LINEAR,
                )
                distorted = np.concatenate((upper, lower), axis=0)

        for _i in range(self.freq_mask_n):
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                distorted[:, f0 : f0 + f] = mask_value

        max_time_mask_t = min(
            self.time_mask_t, math.floor(num_frames * self.time_mask_p)
        )
        if max_time_mask_t < 1:
            return distorted

        for _i in range(self.time_mask_n):
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                distorted[t0 : t0 + t, :] = mask_value

        return distorted
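A small self-contained check of the masking behaviour above (time warping disabled); the mask widths and the 100x40 dummy spectrogram are arbitrary illustrations, not settings from this commit:

import numpy as np

from fairseq.data.audio.feature_transforms.specaugment import SpecAugmentTransform

np.random.seed(0)
# Two time masks of width < 10 frames and two frequency masks of width < 5 bins;
# masked cells are filled with the per-utterance mean (mask_value=None).
transform = SpecAugmentTransform(
    time_warp_w=0, freq_mask_n=2, freq_mask_f=5,
    time_mask_n=2, time_mask_t=10, time_mask_p=1.0, mask_value=None,
)
spec = np.random.randn(100, 40).astype(np.float32)  # T x F
out = transform(spec)
print((out == spec.mean()).sum(), "cells were set to the mask value")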
PyTorch/NLP/new-Transformer/fairseq/data/audio/feature_transforms/utterance_cmvn.py
0 → 100644
View file @ c0f05c10
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
    """Utterance-level CMVN (cepstral mean and variance normalization)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return UtteranceCMVN(
            _config.get("norm_means", True),
            _config.get("norm_vars", True),
        )

    def __init__(self, norm_means=True, norm_vars=True):
        self.norm_means, self.norm_vars = norm_means, norm_vars

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
        )

    def __call__(self, x):
        mean = x.mean(axis=0)
        square_sums = (x**2).sum(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)
        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

        return x
PyTorch/NLP/new-Transformer/fairseq/data/audio/frm_text_to_speech_dataset.py
0 → 100644
View file @ c0f05c10
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import csv
import logging
import os.path as op
from typing import List, Optional

import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
    TextToSpeechDataset,
    TextToSpeechDatasetCreator,
)

logger = logging.getLogger(__name__)


class FrmTextToSpeechDataset(TextToSpeechDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        do_chunk=False,
        chunk_bound=-1,
        chunk_init=50,
        chunk_incr=5,
        add_eos=True,
        dedup=True,
        ref_fpu=-1,
    ):
        # It assumes texts are encoded at a fixed frame-rate
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        self.do_chunk = do_chunk
        self.chunk_bound = chunk_bound
        self.chunk_init = chunk_init
        self.chunk_incr = chunk_incr
        self.add_eos = add_eos
        self.dedup = dedup
        self.ref_fpu = ref_fpu

        self.chunk_size = -1

        if do_chunk:
            assert self.chunk_incr >= 0
            assert self.pre_tokenizer is None

    def __getitem__(self, index):
        index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
        if target[-1].item() == self.tgt_dict.eos_index:
            target = target[:-1]

        fpu = source.size(0) / target.size(0)  # frame-per-unit
        fps = self.n_frames_per_step
        assert (
            self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
        ), f"{fpu * fps} != {self.ref_fpu}"

        # only chunk training split
        if self.is_train_split and self.do_chunk and self.chunk_size > 0:
            lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
            text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
            size = len(text)
            chunk_size = min(self.chunk_size, size)
            chunk_start = np.random.randint(size - chunk_size + 1)
            text = text[chunk_start : chunk_start + chunk_size]
            target = torch.cat((lang, text), 0)

            f_size = int(np.floor(chunk_size * fpu))
            f_start = int(np.floor(chunk_start * fpu))
            assert f_size > 0
            source = source[f_start : f_start + f_size, :]

        if self.dedup:
            target = torch.unique_consecutive(target)

        if self.add_eos:
            eos_idx = self.tgt_dict.eos_index
            target = torch.cat((target, torch.LongTensor([eos_idx])), 0)

        return index, source, target, speaker_id

    def set_epoch(self, epoch):
        if self.is_train_split and self.do_chunk:
            old = self.chunk_size
            self.chunk_size = self.chunk_init + epoch * self.chunk_incr
            if self.chunk_bound > 0:
                self.chunk_size = min(self.chunk_size, self.chunk_bound)
            logger.info(
                (
                    f"{self.split}: setting chunk size "
                    f"from {old} to {self.chunk_size}"
                )
            )


class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
    # inherit for key names
    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        n_frames_per_step: int,
        speaker_to_id,
        do_chunk: bool = False,
        chunk_bound: int = -1,
        chunk_init: int = 50,
        chunk_incr: int = 5,
        add_eos: bool = True,
        dedup: bool = True,
        ref_fpu: float = -1,
    ) -> FrmTextToSpeechDataset:
        tsv_path = op.join(root, f"{split}.tsv")
        if not op.isfile(tsv_path):
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            s = [dict(e) for e in reader]
            assert len(s) > 0

        ids = [ss[cls.KEY_ID] for ss in s]
        audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
        n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
        tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
        src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
        speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
        src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
        tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]

        return FrmTextToSpeechDataset(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            do_chunk=do_chunk,
            chunk_bound=chunk_bound,
            chunk_init=chunk_init,
            chunk_incr=chunk_incr,
            add_eos=add_eos,
            dedup=dedup,
            ref_fpu=ref_fpu,
        )
PyTorch/NLP/new-Transformer/fairseq/data/audio/hubert_dataset.py
0 → 100644
View file @ c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)


def load_audio(manifest_path, max_keep, min_keep):
    n_long, n_short = 0, 0
    names, inds, sizes = [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                inds.append(ind)
                sizes.append(sz)
        tot = ind + 1
        logger.info(
            (
                f"max_keep={max_keep}, min_keep={min_keep}, "
                f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
                f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
            )
        )
    return root, names, inds, tot, sizes


def load_label(label_path, inds, tot):
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        assert (
            len(labels) == tot
        ), f"number of labels does not match ({len(labels)} != {tot})"
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot):
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind + 1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )


class HubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
    ):
        self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, float)
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
        )

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        wav, cur_sample_rate = sf.read(wav_path)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        net_input = {"source": collated_audios, "padding_mask": padding_mask}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1.0:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
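A hedged sketch of the inputs HubertDataset expects: a manifest TSV whose first line is the audio root and whose remaining lines are `<relative_path>\t<num_samples>`, plus one label file per target with one line of space-separated units per utterance. All paths, the dictionary, the 50 Hz label rate, and the size limits below are illustrative assumptions, not values fixed by this commit:

from fairseq.data import Dictionary
from fairseq.data.audio.hubert_dataset import HubertDataset

# Hypothetical manifest (first line = audio root, then "<path>\t<num_samples>"):
#   /data/librispeech
#   train/116-288045-0000.flac	156560
#   train/116-288045-0001.flac	97120
dictionary = Dictionary.load("/data/labels/dict.km.txt")  # hypothetical path

dataset = HubertDataset(
    manifest_path="/data/manifests/train.tsv",
    sample_rate=16000,
    label_paths=["/data/labels/train.km"],
    label_rates=50.0,  # assumed: k-means labels at 50 frames per second
    pad_list=[dictionary.pad()],
    eos_list=[dictionary.eos()],
    label_processors=[
        lambda x: dictionary.encode_line(
            x, append_eos=False, add_if_not_exist=False
        )
    ],
    min_keep_sample_size=32000,
    max_sample_size=250000,
    pad_audio=False,
    normalize=False,
    random_crop=True,
)
batch = dataset.collater([dataset[0], dataset[1]])
print(batch["net_input"]["source"].shape)  # (2, cropped_length)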