Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
c0f05c10
Commit
c0f05c10
authored
Nov 29, 2022
by
hepj
Browse files
更新transformer代码
parent
c056df78
Changes
321
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2386 additions
and
0 deletions
+2386
-0
PyTorch/NLP/new-Transformer/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml
...q/config/model/transformer_lm/transformer_lm_wiki103.yaml
+36
-0
PyTorch/NLP/new-Transformer/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml
...ormer/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml
+5
-0
PyTorch/NLP/new-Transformer/fairseq/config/model/wav2vec2/wav2vec2_base.yaml
...nsformer/fairseq/config/model/wav2vec2/wav2vec2_base.yaml
+8
-0
PyTorch/NLP/new-Transformer/fairseq/config/model/wav2vec2/wav2vec2_large.yaml
...sformer/fairseq/config/model/wav2vec2/wav2vec2_large.yaml
+20
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/__init__.py
PyTorch/NLP/new-Transformer/fairseq/criterions/__init__.py
+36
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/adaptive_loss.py
...h/NLP/new-Transformer/fairseq/criterions/adaptive_loss.py
+123
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/composite_loss.py
.../NLP/new-Transformer/fairseq/criterions/composite_loss.py
+100
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/cross_entropy.py
...h/NLP/new-Transformer/fairseq/criterions/cross_entropy.py
+90
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/ctc.py
PyTorch/NLP/new-Transformer/fairseq/criterions/ctc.py
+295
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/fairseq_criterion.py
...P/new-Transformer/fairseq/criterions/fairseq_criterion.py
+120
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/fastspeech2_loss.py
...LP/new-Transformer/fairseq/criterions/fastspeech2_loss.py
+136
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/hubert_criterion.py
...LP/new-Transformer/fairseq/criterions/hubert_criterion.py
+194
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy.py
...former/fairseq/criterions/label_smoothed_cross_entropy.py
+167
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py
...terions/label_smoothed_cross_entropy_latency_augmented.py
+220
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
...criterions/label_smoothed_cross_entropy_with_alignment.py
+130
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py
...irseq/criterions/label_smoothed_cross_entropy_with_ctc.py
+96
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/legacy_masked_lm.py
...LP/new-Transformer/fairseq/criterions/legacy_masked_lm.py
+177
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/masked_lm.py
PyTorch/NLP/new-Transformer/fairseq/criterions/masked_lm.py
+98
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/model_criterion.py
...NLP/new-Transformer/fairseq/criterions/model_criterion.py
+155
-0
PyTorch/NLP/new-Transformer/fairseq/criterions/nat_loss.py
PyTorch/NLP/new-Transformer/fairseq/criterions/nat_loss.py
+180
-0
No files found.
Too many changes to show.
To preserve performance only
321 of 321+
files are displayed.
Plain diff
Email patch
PyTorch/NLP/new-Transformer/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml
0 → 100644
View file @
c0f05c10
# @package _group_
activation_fn
:
"
relu"
dropout
:
0.3
attention_dropout
:
0.1
activation_dropout
:
0.1
relu_dropout
:
0.1
decoder_embed_dim
:
1024
decoder_output_dim
:
1024
decoder_input_dim
:
1024
decoder_ffn_embed_dim
:
4096
decoder_layers
:
16
decoder_attention_heads
:
8
decoder_normalize_before
:
true
no_decoder_final_norm
:
true
adaptive_softmax_cutoff
:
"
20000,60000"
adaptive_softmax_dropout
:
0.2
adaptive_softmax_factor
:
4
no_token_positional_embeddings
:
false
share_decoder_input_output_embed
:
false
character_embeddings
:
false
character_filters
:
"
[(1,
64),
(2,
128),
(3,
192),
(4,
256),
(5,
256),
(6,
256),
(7,
256)]"
character_embedding_dim
:
4
char_embedder_highway_layers
:
2
adaptive_input
:
true
adaptive_input_factor
:
4
adaptive_input_cutoff
:
"
20000,60000"
tie_adaptive_weights
:
true
tie_adaptive_proj
:
true
decoder_learned_pos
:
false
decoder_layerdrop
:
0
decoder_layers_to_keep
:
null
layernorm_embedding
:
false
no_scale_embedding
:
false
quant_noise_pq
:
0
quant_noise_pq_block_size
:
8
quant_noise_scalar
:
0
PyTorch/NLP/new-Transformer/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml
0 → 100644
View file @
c0f05c10
# @package _group_
activation
:
gelu
vq_type
:
gumbel
vq_depth
:
2
combine_groups
:
true
PyTorch/NLP/new-Transformer/fairseq/config/model/wav2vec2/wav2vec2_base.yaml
0 → 100644
View file @
c0f05c10
# @package _group_
quantize_targets
:
true
final_dim
:
256
encoder_layerdrop
:
0.05
dropout_input
:
0.1
dropout_features
:
0.1
feature_grad_mult
:
0.1
PyTorch/NLP/new-Transformer/fairseq/config/model/wav2vec2/wav2vec2_large.yaml
0 → 100644
View file @
c0f05c10
# @package _group_
quantize_targets
:
true
extractor_mode
:
layer_norm
layer_norm_first
:
true
final_dim
:
768
latent_temp
:
[
2.0
,
0.1
,
0.999995
]
encoder_layerdrop
:
0.0
dropout_input
:
0.0
dropout_features
:
0.0
dropout
:
0.0
attention_dropout
:
0.0
conv_bias
:
true
encoder_layers
:
24
encoder_embed_dim
:
1024
encoder_ffn_embed_dim
:
4096
encoder_attention_heads
:
16
feature_grad_mult
:
1.0
PyTorch/NLP/new-Transformer/fairseq/criterions/__init__.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import
importlib
import
os
from
fairseq
import
registry
from
fairseq.criterions.fairseq_criterion
import
(
# noqa
FairseqCriterion
,
LegacyFairseqCriterion
,
)
from
omegaconf
import
DictConfig
(
build_criterion_
,
register_criterion
,
CRITERION_REGISTRY
,
CRITERION_DATACLASS_REGISTRY
,
)
=
registry
.
setup_registry
(
"--criterion"
,
base_class
=
FairseqCriterion
,
default
=
"cross_entropy"
)
def
build_criterion
(
cfg
:
DictConfig
,
task
):
return
build_criterion_
(
cfg
,
task
)
# automatically import any Python files in the criterions/ directory
for
file
in
sorted
(
os
.
listdir
(
os
.
path
.
dirname
(
__file__
))):
if
file
.
endswith
(
".py"
)
and
not
file
.
startswith
(
"_"
):
file_name
=
file
[:
file
.
find
(
".py"
)]
importlib
.
import_module
(
"fairseq.criterions."
+
file_name
)
PyTorch/NLP/new-Transformer/fairseq/criterions/adaptive_loss.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.dataclass.constants
import
DDP_BACKEND_CHOICES
from
omegaconf
import
II
@
dataclass
class
AdaptiveLossConfig
(
FairseqDataclass
):
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
ddp_backend
:
DDP_BACKEND_CHOICES
=
II
(
"distributed_training.ddp_backend"
)
@
register_criterion
(
"adaptive_loss"
,
dataclass
=
AdaptiveLossConfig
)
class
AdaptiveLoss
(
FairseqCriterion
):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def
__init__
(
self
,
task
,
sentence_avg
):
super
().
__init__
(
task
)
self
.
sentence_avg
=
sentence_avg
@
classmethod
def
build_criterion
(
cls
,
cfg
:
AdaptiveLossConfig
,
task
):
if
cfg
.
ddp_backend
in
{
"c10d"
,
"pytorch_ddp"
}:
raise
Exception
(
"AdaptiveLoss is not compatible with the PyTorch "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=legacy_ddp` instead."
)
return
cls
(
task
,
cfg
.
sentence_avg
)
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert
(
hasattr
(
model
.
decoder
,
"adaptive_softmax"
)
and
model
.
decoder
.
adaptive_softmax
is
not
None
)
adaptive_softmax
=
model
.
decoder
.
adaptive_softmax
net_output
=
model
(
**
sample
[
"net_input"
])
orig_target
=
model
.
get_targets
(
sample
,
net_output
)
nsentences
=
orig_target
.
size
(
0
)
orig_target
=
orig_target
.
view
(
-
1
)
bsz
=
orig_target
.
size
(
0
)
logits
,
target
=
adaptive_softmax
(
net_output
[
0
],
orig_target
)
assert
len
(
target
)
==
len
(
logits
)
loss
=
net_output
[
0
].
new
(
1
if
reduce
else
bsz
).
zero_
()
for
i
in
range
(
len
(
target
)):
if
target
[
i
]
is
not
None
:
assert
target
[
i
].
min
()
>=
0
and
target
[
i
].
max
()
<=
logits
[
i
].
size
(
1
)
loss
+=
F
.
cross_entropy
(
logits
[
i
],
target
[
i
],
ignore_index
=
self
.
padding_idx
,
reduction
=
"sum"
if
reduce
else
"none"
,
)
orig
=
utils
.
strip_pad
(
orig_target
,
self
.
padding_idx
)
ntokens
=
orig
.
numel
()
sample_size
=
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
ntokens
logging_output
=
{
"loss"
:
loss
.
data
,
"ntokens"
:
ntokens
,
"nsentences"
:
nsentences
,
"sample_size"
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
utils
.
item
(
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
))
ntokens
=
utils
.
item
(
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
))
sample_size
=
utils
.
item
(
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"nll_loss"
].
avg
)
)
else
:
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"loss"
].
avg
)
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/composite_loss.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
fairseq
import
utils
from
fairseq.criterions
import
LegacyFairseqCriterion
,
register_criterion
from
torch
import
nn
@
register_criterion
(
"composite_loss"
)
class
CompositeLoss
(
LegacyFairseqCriterion
):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""
def
__init__
(
self
,
args
,
task
):
super
().
__init__
(
args
,
task
)
self
.
underlying_criterion
=
args
.
underlying_criterion
@
staticmethod
def
add_args
(
parser
):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser
.
add_argument
(
'--underlying-criterion'
,
type
=
str
,
metavar
=
'VAL'
,
required
=
True
,
help
=
'underlying criterion to use for the composite loss'
)
# fmt: on
@
staticmethod
def
build_underlying_criterion
(
args
,
task
):
saved_criterion
=
args
.
criterion
args
.
criterion
=
args
.
underlying_criterion
assert
saved_criterion
!=
args
.
underlying_criterion
underlying_criterion
=
task
.
build_criterion
(
args
)
args
.
criterion
=
saved_criterion
return
underlying_criterion
@
classmethod
def
build_criterion
(
cls
,
args
,
task
):
underlying_criterion
=
CompositeLoss
.
build_underlying_criterion
(
args
,
task
)
class
FakeModel
(
nn
.
Module
):
def
__init__
(
self
,
model
,
net_out
,
target
):
super
().
__init__
()
self
.
model
=
model
self
.
net_out
=
net_out
self
.
target
=
target
def
forward
(
self
,
**
unused
):
return
self
.
net_out
def
get_normalized_probs
(
self
,
net_output
,
log_probs
,
sample
=
None
):
return
self
.
model
.
get_normalized_probs
(
net_output
,
log_probs
,
sample
=
sample
)
def
get_targets
(
self
,
*
unused
):
return
self
.
target
@
property
def
decoder
(
self
):
return
self
.
model
.
decoder
class
_CompositeLoss
(
LegacyFairseqCriterion
):
def
__init__
(
self
,
args
,
task
,
underlying_criterion
):
super
().
__init__
(
args
,
task
)
self
.
underlying_criterion
=
underlying_criterion
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
net_outputs
=
model
(
**
sample
[
"net_input"
])
targets
=
sample
[
"target"
]
bsz
=
targets
[
0
].
size
(
0
)
loss
=
net_outputs
[
0
][
0
].
new
(
1
if
reduce
else
bsz
).
float
().
zero_
()
sample_size
=
0
logging_output
=
{}
for
o
,
t
in
zip
(
net_outputs
[
0
],
targets
):
m
=
FakeModel
(
model
,
(
o
,
net_outputs
[
1
]),
t
)
sample
[
"target"
]
=
t
l
,
ss
,
logging_output
=
self
.
underlying_criterion
(
m
,
sample
,
reduce
)
loss
+=
l
sample_size
+=
ss
loss
.
div_
(
len
(
targets
))
sample_size
/=
len
(
targets
)
logging_output
[
"loss"
]
=
utils
.
item
(
loss
.
data
)
if
reduce
else
loss
.
data
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
aggregate_logging_outputs
(
logging_outputs
):
return
underlying_criterion
.
__class__
.
aggregate_logging_outputs
(
logging_outputs
)
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
underlying_criterion
.
__class__
.
reduce_metrics
(
logging_outputs
)
return
_CompositeLoss
(
args
,
task
,
underlying_criterion
)
PyTorch/NLP/new-Transformer/fairseq/criterions/cross_entropy.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
omegaconf
import
II
@
dataclass
class
CrossEntropyCriterionConfig
(
FairseqDataclass
):
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
@
register_criterion
(
"cross_entropy"
,
dataclass
=
CrossEntropyCriterionConfig
)
class
CrossEntropyCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
sentence_avg
):
super
().
__init__
(
task
)
self
.
sentence_avg
=
sentence_avg
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output
=
model
(
**
sample
[
"net_input"
])
loss
,
_
=
self
.
compute_loss
(
model
,
net_output
,
sample
,
reduce
=
reduce
)
sample_size
=
(
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
)
logging_output
=
{
"loss"
:
loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"target"
].
size
(
0
),
"sample_size"
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
def
compute_loss
(
self
,
model
,
net_output
,
sample
,
reduce
=
True
):
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
)
lprobs
=
lprobs
.
view
(
-
1
,
lprobs
.
size
(
-
1
))
target
=
model
.
get_targets
(
sample
,
net_output
).
view
(
-
1
)
loss
=
F
.
nll_loss
(
lprobs
,
target
,
ignore_index
=
self
.
padding_idx
,
reduction
=
"sum"
if
reduce
else
"none"
,
)
return
loss
,
loss
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
ntokens
=
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
# we divide by log(2) to convert the loss from base e to base 2
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"nll_loss"
].
avg
)
)
else
:
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"loss"
].
avg
)
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/ctc.py
0 → 100644
View file @
c0f05c10
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
math
from
argparse
import
Namespace
from
dataclasses
import
dataclass
,
field
from
omegaconf
import
II
from
typing
import
Optional
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.data.data_utils
import
post_process
from
fairseq.tasks
import
FairseqTask
from
fairseq.logging.meters
import
safe_round
@
dataclass
class
CtcCriterionConfig
(
FairseqDataclass
):
zero_infinity
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"zero inf loss when source length <= target length"
},
)
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
post_process
:
str
=
field
(
default
=
"letter"
,
metadata
=
{
"help"
:
"how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
wer_kenlm_model
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"if this is provided, use kenlm to compute wer (along with other wer_* args)"
},
)
wer_lexicon
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"lexicon to use with wer_kenlm_model"
},
)
wer_lm_weight
:
float
=
field
(
default
=
2.0
,
metadata
=
{
"help"
:
"lm weight to use with wer_kenlm_model"
},
)
wer_word_score
:
float
=
field
(
default
=-
1.0
,
metadata
=
{
"help"
:
"lm word score to use with wer_kenlm_model"
},
)
wer_args
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
},
)
@
register_criterion
(
"ctc"
,
dataclass
=
CtcCriterionConfig
)
class
CtcCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
cfg
:
CtcCriterionConfig
,
task
:
FairseqTask
):
super
().
__init__
(
task
)
self
.
blank_idx
=
(
task
.
target_dictionary
.
index
(
task
.
blank_symbol
)
if
hasattr
(
task
,
"blank_symbol"
)
else
0
)
self
.
pad_idx
=
task
.
target_dictionary
.
pad
()
self
.
eos_idx
=
task
.
target_dictionary
.
eos
()
self
.
post_process
=
cfg
.
post_process
if
cfg
.
wer_args
is
not
None
:
(
cfg
.
wer_kenlm_model
,
cfg
.
wer_lexicon
,
cfg
.
wer_lm_weight
,
cfg
.
wer_word_score
,
)
=
eval
(
cfg
.
wer_args
)
if
cfg
.
wer_kenlm_model
is
not
None
and
cfg
.
wer_kenlm_model
!=
""
:
from
examples.speech_recognition.w2l_decoder
import
W2lKenLMDecoder
dec_args
=
Namespace
()
dec_args
.
nbest
=
1
dec_args
.
criterion
=
"ctc"
dec_args
.
kenlm_model
=
cfg
.
wer_kenlm_model
dec_args
.
lexicon
=
cfg
.
wer_lexicon
dec_args
.
beam
=
50
dec_args
.
beam_size_token
=
min
(
50
,
len
(
task
.
target_dictionary
))
dec_args
.
beam_threshold
=
min
(
50
,
len
(
task
.
target_dictionary
))
dec_args
.
lm_weight
=
cfg
.
wer_lm_weight
dec_args
.
word_score
=
cfg
.
wer_word_score
dec_args
.
unk_weight
=
-
math
.
inf
dec_args
.
sil_weight
=
0
self
.
w2l_decoder
=
W2lKenLMDecoder
(
dec_args
,
task
.
target_dictionary
)
else
:
self
.
w2l_decoder
=
None
self
.
zero_infinity
=
cfg
.
zero_infinity
self
.
sentence_avg
=
cfg
.
sentence_avg
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
net_output
=
model
(
**
sample
[
"net_input"
])
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
).
contiguous
()
# (T, B, C) from the encoder
if
"src_lengths"
in
sample
[
"net_input"
]:
input_lengths
=
sample
[
"net_input"
][
"src_lengths"
]
else
:
if
net_output
[
"padding_mask"
]
is
not
None
:
non_padding_mask
=
~
net_output
[
"padding_mask"
]
input_lengths
=
non_padding_mask
.
long
().
sum
(
-
1
)
else
:
input_lengths
=
lprobs
.
new_full
(
(
lprobs
.
size
(
1
),),
lprobs
.
size
(
0
),
dtype
=
torch
.
long
)
pad_mask
=
(
sample
[
"target"
]
!=
self
.
pad_idx
)
&
(
sample
[
"target"
]
!=
self
.
eos_idx
)
targets_flat
=
sample
[
"target"
].
masked_select
(
pad_mask
)
if
"target_lengths"
in
sample
:
target_lengths
=
sample
[
"target_lengths"
]
else
:
target_lengths
=
pad_mask
.
sum
(
-
1
)
with
torch
.
backends
.
cudnn
.
flags
(
enabled
=
False
):
loss
=
F
.
ctc_loss
(
lprobs
,
targets_flat
,
input_lengths
,
target_lengths
,
blank
=
self
.
blank_idx
,
reduction
=
"sum"
,
zero_infinity
=
self
.
zero_infinity
,
)
ntokens
=
(
sample
[
"ntokens"
]
if
"ntokens"
in
sample
else
target_lengths
.
sum
().
item
()
)
sample_size
=
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
ntokens
logging_output
=
{
"loss"
:
utils
.
item
(
loss
.
data
),
# * sample['ntokens'],
"ntokens"
:
ntokens
,
"nsentences"
:
sample
[
"id"
].
numel
(),
"sample_size"
:
sample_size
,
}
if
not
model
.
training
:
import
editdistance
with
torch
.
no_grad
():
lprobs_t
=
lprobs
.
transpose
(
0
,
1
).
float
().
contiguous
().
cpu
()
c_err
=
0
c_len
=
0
w_errs
=
0
w_len
=
0
wv_errs
=
0
for
lp
,
t
,
inp_l
in
zip
(
lprobs_t
,
sample
[
"target_label"
]
if
"target_label"
in
sample
else
sample
[
"target"
],
input_lengths
,
):
lp
=
lp
[:
inp_l
].
unsqueeze
(
0
)
decoded
=
None
if
self
.
w2l_decoder
is
not
None
:
decoded
=
self
.
w2l_decoder
.
decode
(
lp
)
if
len
(
decoded
)
<
1
:
decoded
=
None
else
:
decoded
=
decoded
[
0
]
if
len
(
decoded
)
<
1
:
decoded
=
None
else
:
decoded
=
decoded
[
0
]
p
=
(
t
!=
self
.
task
.
target_dictionary
.
pad
())
&
(
t
!=
self
.
task
.
target_dictionary
.
eos
()
)
targ
=
t
[
p
]
targ_units
=
self
.
task
.
target_dictionary
.
string
(
targ
)
targ_units_arr
=
targ
.
tolist
()
toks
=
lp
.
argmax
(
dim
=-
1
).
unique_consecutive
()
pred_units_arr
=
toks
[
toks
!=
self
.
blank_idx
].
tolist
()
c_err
+=
editdistance
.
eval
(
pred_units_arr
,
targ_units_arr
)
c_len
+=
len
(
targ_units_arr
)
targ_words
=
post_process
(
targ_units
,
self
.
post_process
).
split
()
pred_units
=
self
.
task
.
target_dictionary
.
string
(
pred_units_arr
)
pred_words_raw
=
post_process
(
pred_units
,
self
.
post_process
).
split
()
if
decoded
is
not
None
and
"words"
in
decoded
:
pred_words
=
decoded
[
"words"
]
w_errs
+=
editdistance
.
eval
(
pred_words
,
targ_words
)
wv_errs
+=
editdistance
.
eval
(
pred_words_raw
,
targ_words
)
else
:
dist
=
editdistance
.
eval
(
pred_words_raw
,
targ_words
)
w_errs
+=
dist
wv_errs
+=
dist
w_len
+=
len
(
targ_words
)
logging_output
[
"wv_errors"
]
=
wv_errs
logging_output
[
"w_errors"
]
=
w_errs
logging_output
[
"w_total"
]
=
w_len
logging_output
[
"c_errors"
]
=
c_err
logging_output
[
"c_total"
]
=
c_len
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
utils
.
item
(
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
))
ntokens
=
utils
.
item
(
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
))
nsentences
=
utils
.
item
(
sum
(
log
.
get
(
"nsentences"
,
0
)
for
log
in
logging_outputs
)
)
sample_size
=
utils
.
item
(
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
metrics
.
log_scalar
(
"ntokens"
,
ntokens
)
metrics
.
log_scalar
(
"nsentences"
,
nsentences
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
c_errors
=
sum
(
log
.
get
(
"c_errors"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_c_errors"
,
c_errors
)
c_total
=
sum
(
log
.
get
(
"c_total"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_c_total"
,
c_total
)
w_errors
=
sum
(
log
.
get
(
"w_errors"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_w_errors"
,
w_errors
)
wv_errors
=
sum
(
log
.
get
(
"wv_errors"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_wv_errors"
,
wv_errors
)
w_total
=
sum
(
log
.
get
(
"w_total"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"_w_total"
,
w_total
)
if
c_total
>
0
:
metrics
.
log_derived
(
"uer"
,
lambda
meters
:
safe_round
(
meters
[
"_c_errors"
].
sum
*
100.0
/
meters
[
"_c_total"
].
sum
,
3
)
if
meters
[
"_c_total"
].
sum
>
0
else
float
(
"nan"
),
)
if
w_total
>
0
:
metrics
.
log_derived
(
"wer"
,
lambda
meters
:
safe_round
(
meters
[
"_w_errors"
].
sum
*
100.0
/
meters
[
"_w_total"
].
sum
,
3
)
if
meters
[
"_w_total"
].
sum
>
0
else
float
(
"nan"
),
)
metrics
.
log_derived
(
"raw_wer"
,
lambda
meters
:
safe_round
(
meters
[
"_wv_errors"
].
sum
*
100.0
/
meters
[
"_w_total"
].
sum
,
3
)
if
meters
[
"_w_total"
].
sum
>
0
else
float
(
"nan"
),
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/fairseq_criterion.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
inspect
from
typing
import
Any
,
Dict
,
List
from
fairseq
import
metrics
,
utils
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.dataclass.utils
import
gen_parser_from_dataclass
from
torch.nn.modules.loss
import
_Loss
class
FairseqCriterion
(
_Loss
):
def
__init__
(
self
,
task
):
super
().
__init__
()
self
.
task
=
task
if
hasattr
(
task
,
"target_dictionary"
):
tgt_dict
=
task
.
target_dictionary
self
.
padding_idx
=
tgt_dict
.
pad
()
if
tgt_dict
is
not
None
else
-
100
@
classmethod
def
add_args
(
cls
,
parser
):
"""Add criterion-specific arguments to the parser."""
dc
=
getattr
(
cls
,
"__dataclass"
,
None
)
if
dc
is
not
None
:
gen_parser_from_dataclass
(
parser
,
dc
())
@
classmethod
def
build_criterion
(
cls
,
cfg
:
FairseqDataclass
,
task
):
"""Construct a criterion from command-line args."""
# arguments in the __init__.
init_args
=
{}
for
p
in
inspect
.
signature
(
cls
).
parameters
.
values
():
if
(
p
.
kind
==
p
.
POSITIONAL_ONLY
or
p
.
kind
==
p
.
VAR_POSITIONAL
or
p
.
kind
==
p
.
VAR_KEYWORD
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
raise
NotImplementedError
(
"{} not supported"
.
format
(
p
.
kind
))
assert
p
.
kind
in
{
p
.
POSITIONAL_OR_KEYWORD
,
p
.
KEYWORD_ONLY
}
if
p
.
name
==
"task"
:
init_args
[
"task"
]
=
task
elif
p
.
name
==
"cfg"
:
init_args
[
"cfg"
]
=
cfg
elif
hasattr
(
cfg
,
p
.
name
):
init_args
[
p
.
name
]
=
getattr
(
cfg
,
p
.
name
)
elif
p
.
default
!=
p
.
empty
:
pass
# we'll use the default value
else
:
raise
NotImplementedError
(
"Unable to infer Criterion arguments, please implement "
"{}.build_criterion"
.
format
(
cls
.
__name__
)
)
return
cls
(
**
init_args
)
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise
NotImplementedError
@
staticmethod
def
aggregate_logging_outputs
(
logging_outputs
:
List
[
Dict
[
str
,
Any
]]
)
->
Dict
[
str
,
Any
]:
"""Aggregate logging outputs from data parallel training."""
utils
.
deprecation_warning
(
"The aggregate_logging_outputs API is deprecated. "
"Please use the reduce_metrics API instead."
)
raise
NotImplementedError
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
:
List
[
Dict
[
str
,
Any
]])
->
None
:
"""Aggregate logging outputs from data parallel training."""
utils
.
deprecation_warning
(
"Criterions should implement the reduce_metrics API. "
"Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs
=
cls
.
aggregate_logging_outputs
(
logging_outputs
)
for
k
,
v
in
agg_logging_outputs
.
items
():
if
k
in
{
"nsentences"
,
"ntokens"
,
"sample_size"
}:
continue
metrics
.
log_scalar
(
k
,
v
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
False
class
LegacyFairseqCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
args
,
task
):
super
().
__init__
(
task
=
task
)
self
.
args
=
args
utils
.
deprecation_warning
(
"Criterions should take explicit arguments instead of an "
"argparse.Namespace object, please update your criterion by "
"extending FairseqCriterion instead of LegacyFairseqCriterion."
)
@
classmethod
def
build_criterion
(
cls
,
args
,
task
):
"""Construct a criterion from command-line args."""
return
cls
(
args
,
task
)
PyTorch/NLP/new-Transformer/fairseq/criterions/fastspeech2_loss.py
0 → 100644
View file @
c0f05c10
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from
typing
import
List
,
Dict
,
Any
from
dataclasses
import
dataclass
,
field
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.data.data_utils
import
lengths_to_mask
from
fairseq.models.fairseq_model
import
FairseqEncoderModel
@
dataclass
class
FastSpeech2CriterionConfig
(
FairseqDataclass
):
ctc_weight
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"weight for CTC loss"
})
@
register_criterion
(
"fastspeech2"
,
dataclass
=
FastSpeech2CriterionConfig
)
class
FastSpeech2Loss
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
ctc_weight
):
super
().
__init__
(
task
)
self
.
ctc_weight
=
ctc_weight
def
forward
(
self
,
model
:
FairseqEncoderModel
,
sample
,
reduction
=
"mean"
):
src_tokens
=
sample
[
"net_input"
][
"src_tokens"
]
src_lens
=
sample
[
"net_input"
][
"src_lengths"
]
tgt_lens
=
sample
[
"target_lengths"
]
_feat_out
,
_feat_out_post
,
_
,
log_dur_out
,
pitch_out
,
energy_out
=
model
(
src_tokens
=
src_tokens
,
src_lengths
=
src_lens
,
prev_output_tokens
=
sample
[
"net_input"
][
"prev_output_tokens"
],
incremental_state
=
None
,
target_lengths
=
tgt_lens
,
speaker
=
sample
[
"speaker"
],
durations
=
sample
[
"durations"
],
pitches
=
sample
[
"pitches"
],
energies
=
sample
[
"energies"
],
)
src_mask
=
lengths_to_mask
(
sample
[
"net_input"
][
"src_lengths"
])
tgt_mask
=
lengths_to_mask
(
sample
[
"target_lengths"
])
pitches
,
energies
=
sample
[
"pitches"
],
sample
[
"energies"
]
pitch_out
,
pitches
=
pitch_out
[
src_mask
],
pitches
[
src_mask
]
energy_out
,
energies
=
energy_out
[
src_mask
],
energies
[
src_mask
]
feat_out
,
feat
=
_feat_out
[
tgt_mask
],
sample
[
"target"
][
tgt_mask
]
l1_loss
=
F
.
l1_loss
(
feat_out
,
feat
,
reduction
=
reduction
)
if
_feat_out_post
is
not
None
:
l1_loss
+=
F
.
l1_loss
(
_feat_out_post
[
tgt_mask
],
feat
,
reduction
=
reduction
)
pitch_loss
=
F
.
mse_loss
(
pitch_out
,
pitches
,
reduction
=
reduction
)
energy_loss
=
F
.
mse_loss
(
energy_out
,
energies
,
reduction
=
reduction
)
log_dur_out
=
log_dur_out
[
src_mask
]
dur
=
sample
[
"durations"
].
float
()
dur
=
dur
.
half
()
if
log_dur_out
.
type
().
endswith
(
".HalfTensor"
)
else
dur
log_dur
=
torch
.
log
(
dur
+
1
)[
src_mask
]
dur_loss
=
F
.
mse_loss
(
log_dur_out
,
log_dur
,
reduction
=
reduction
)
ctc_loss
=
torch
.
tensor
(
0.0
).
type_as
(
l1_loss
)
if
self
.
ctc_weight
>
0.0
:
lprobs
=
model
.
get_normalized_probs
((
_feat_out
,),
log_probs
=
True
)
lprobs
=
lprobs
.
transpose
(
0
,
1
)
# T x B x C
src_mask
=
lengths_to_mask
(
src_lens
)
src_tokens_flat
=
src_tokens
.
masked_select
(
src_mask
)
ctc_loss
=
(
F
.
ctc_loss
(
lprobs
,
src_tokens_flat
,
tgt_lens
,
src_lens
,
reduction
=
reduction
,
zero_infinity
=
True
,
)
*
self
.
ctc_weight
)
loss
=
l1_loss
+
dur_loss
+
pitch_loss
+
energy_loss
+
ctc_loss
sample_size
=
sample
[
"nsentences"
]
logging_output
=
{
"loss"
:
utils
.
item
(
loss
.
data
),
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"nsentences"
],
"sample_size"
:
sample_size
,
"l1_loss"
:
utils
.
item
(
l1_loss
.
data
),
"dur_loss"
:
utils
.
item
(
dur_loss
.
data
),
"pitch_loss"
:
utils
.
item
(
pitch_loss
.
data
),
"energy_loss"
:
utils
.
item
(
energy_loss
.
data
),
"ctc_loss"
:
utils
.
item
(
ctc_loss
.
data
),
}
return
loss
,
sample_size
,
logging_output
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
:
List
[
Dict
[
str
,
Any
]])
->
None
:
ns
=
[
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
]
ntot
=
sum
(
ns
)
ws
=
[
n
/
(
ntot
+
1e-8
)
for
n
in
ns
]
for
key
in
[
"loss"
,
"l1_loss"
,
"dur_loss"
,
"pitch_loss"
,
"energy_loss"
,
"ctc_loss"
,
]:
vals
=
[
log
.
get
(
key
,
0
)
for
log
in
logging_outputs
]
val
=
sum
(
val
*
w
for
val
,
w
in
zip
(
vals
,
ws
))
metrics
.
log_scalar
(
key
,
val
,
ntot
,
round
=
3
)
metrics
.
log_scalar
(
"sample_size"
,
ntot
,
len
(
logging_outputs
))
# inference metrics
if
"targ_frames"
not
in
logging_outputs
[
0
]:
return
n
=
sum
(
log
.
get
(
"targ_frames"
,
0
)
for
log
in
logging_outputs
)
for
key
,
new_key
in
[
(
"mcd_loss"
,
"mcd_loss"
),
(
"pred_frames"
,
"pred_ratio"
),
(
"nins"
,
"ins_rate"
),
(
"ndel"
,
"del_rate"
),
]:
val
=
sum
(
log
.
get
(
key
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
new_key
,
val
/
n
,
n
,
round
=
3
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
return
False
PyTorch/NLP/new-Transformer/fairseq/criterions/hubert_criterion.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
re
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
@
dataclass
class
HubertCriterionConfig
(
FairseqDataclass
):
pred_masked_weight
:
float
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"weight for predictive loss for masked frames"
},
)
pred_nomask_weight
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"weight for predictive loss for unmasked frames"
},
)
loss_weights
:
Optional
[
List
[
float
]]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"weights for additional loss terms (not first one)"
},
)
log_keys
:
List
[
str
]
=
field
(
default_factory
=
lambda
:
[],
metadata
=
{
"help"
:
"output keys to log"
},
)
@
register_criterion
(
"hubert"
,
dataclass
=
HubertCriterionConfig
)
class
HubertCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
pred_masked_weight
,
pred_nomask_weight
,
loss_weights
=
None
,
log_keys
=
None
,
):
super
().
__init__
(
task
)
self
.
pred_masked_weight
=
pred_masked_weight
self
.
pred_nomask_weight
=
pred_nomask_weight
self
.
loss_weights
=
loss_weights
self
.
log_keys
=
[]
if
log_keys
is
None
else
log_keys
def
forward
(
self
,
model
,
sample
,
reduce
=
True
,
log_pred
=
False
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output
=
model
(
target_list
=
sample
[
"target_list"
],
**
sample
[
"net_input"
])
loss
=
0.0
sample_size
=
0
logging_output
=
{}
reduction
=
"sum"
if
reduce
else
"none"
loss_m_list
=
[]
logp_m_list
=
model
.
get_logits
(
net_output
,
True
)
targ_m_list
=
model
.
get_targets
(
net_output
,
True
)
assert
self
.
pred_masked_weight
==
0
or
len
(
logp_m_list
)
>
0
for
i
,
(
logp_m
,
targ_m
)
in
enumerate
(
zip
(
logp_m_list
,
targ_m_list
)):
loss_m
=
F
.
cross_entropy
(
logp_m
,
targ_m
,
reduction
=
reduction
)
loss_m_list
.
append
(
loss_m
)
logging_output
[
f
"loss_m_
{
i
}
"
]
=
loss_m
.
detach
().
item
()
if
self
.
pred_masked_weight
>
0
:
loss
+=
self
.
pred_masked_weight
*
sum
(
loss_m_list
)
sample_size
+=
targ_m_list
[
0
].
numel
()
loss_u_list
=
[]
logp_u_list
=
model
.
get_logits
(
net_output
,
False
)
targ_u_list
=
model
.
get_targets
(
net_output
,
False
)
assert
self
.
pred_nomask_weight
==
0
or
len
(
logp_u_list
)
>
0
for
i
,
(
logp_u
,
targ_u
)
in
enumerate
(
zip
(
logp_u_list
,
targ_u_list
)):
loss_u
=
F
.
cross_entropy
(
logp_u
,
targ_u
,
reduction
=
reduction
)
loss_u_list
.
append
(
loss_u
)
logging_output
[
f
"loss_u_
{
i
}
"
]
=
loss_u
.
detach
().
item
()
if
self
.
pred_nomask_weight
>
0
:
loss
+=
self
.
pred_nomask_weight
*
sum
(
loss_u_list
)
sample_size
+=
targ_u_list
[
0
].
numel
()
if
self
.
loss_weights
is
not
None
:
assert
hasattr
(
model
,
"get_extra_losses"
)
extra_losses
,
names
=
model
.
get_extra_losses
(
net_output
)
if
torch
.
is_tensor
(
extra_losses
):
extra_losses
=
[
extra_losses
]
names
=
[
names
]
if
len
(
self
.
loss_weights
)
==
1
and
len
(
extra_losses
)
!=
1
:
self
.
loss_weights
=
[
self
.
loss_weights
[
0
]]
*
len
(
extra_losses
)
assert
len
(
extra_losses
)
==
len
(
self
.
loss_weights
),
f
"
{
len
(
extra_losses
)
}
,
{
len
(
self
.
loss_weights
)
}
"
for
p
,
n
,
coef
in
zip
(
extra_losses
,
names
,
self
.
loss_weights
):
if
coef
!=
0
and
p
is
not
None
:
p
=
coef
*
p
.
float
()
*
sample_size
loss
+=
p
logging_output
[
f
"loss_
{
n
}
"
]
=
p
.
item
()
logging_output
=
{
"loss"
:
loss
.
item
()
if
reduce
else
loss
,
"ntokens"
:
sample_size
,
"nsentences"
:
sample
[
"id"
].
numel
(),
"sample_size"
:
sample_size
,
**
logging_output
,
}
for
lk
in
self
.
log_keys
:
if
lk
in
net_output
:
logging_output
[
lk
]
=
float
((
net_output
[
lk
]))
def
compute_correct
(
logits
):
if
logits
.
numel
()
==
0
:
return
0
,
0
else
:
assert
logits
.
dim
()
>
1
,
logits
.
shape
max
=
logits
.
argmax
(
-
1
)
==
0
min
=
logits
.
argmin
(
-
1
)
==
0
both
=
max
&
min
corr
=
max
.
long
().
sum
().
item
()
-
both
.
long
().
sum
().
item
()
count
=
max
.
numel
()
return
corr
,
count
with
torch
.
no_grad
():
for
i
,
logp_m
in
enumerate
(
logp_m_list
):
corr_m
,
count_m
=
compute_correct
(
logp_m
)
logging_output
[
f
"correct_m_
{
i
}
"
]
=
corr_m
logging_output
[
f
"count_m_
{
i
}
"
]
=
count_m
for
i
,
logp_u
in
enumerate
(
logp_u_list
):
corr_u
,
count_u
=
compute_correct
(
logp_u
)
logging_output
[
f
"correct_u_
{
i
}
"
]
=
corr_u
logging_output
[
f
"count_u_
{
i
}
"
]
=
count_u
return
loss
,
sample_size
,
logging_output
@
staticmethod
def
reduce_metrics
(
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum
=
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
ntokens
=
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
if
sample_size
!=
ntokens
:
metrics
.
log_scalar
(
"nll_loss"
,
loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"nll_loss"
].
avg
)
)
else
:
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"loss"
].
avg
)
)
counts
=
{}
for
lk
in
logging_outputs
[
0
].
keys
():
if
lk
.
startswith
(
"count_"
):
val
=
sum
(
log
[
lk
]
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
lk
,
val
)
counts
[
lk
]
=
val
for
lk
in
logging_outputs
[
0
].
keys
():
if
lk
.
startswith
(
"loss_"
):
val
=
sum
(
log
[
lk
]
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
lk
,
val
/
sample_size
/
math
.
log
(
2
),
round
=
3
)
elif
lk
.
startswith
(
"correct_"
):
val
=
sum
(
log
[
lk
]
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
lk
,
val
/
counts
[
re
.
sub
(
"correct"
,
"count"
,
lk
)])
@
staticmethod
def
aggregate_logging_outputs
(
logging_outputs
):
"""Aggregate logging outputs from data parallel training."""
raise
NotImplementedError
()
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
False
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
,
field
import
torch
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
FairseqCriterion
,
register_criterion
from
fairseq.dataclass
import
FairseqDataclass
from
omegaconf
import
II
@
dataclass
class
LabelSmoothedCrossEntropyCriterionConfig
(
FairseqDataclass
):
label_smoothing
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"epsilon for label smoothing, 0 means no label smoothing"
},
)
report_accuracy
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"report accuracy metric"
},
)
ignore_prefix_size
:
int
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"Ignore first N tokens"
},
)
sentence_avg
:
bool
=
II
(
"optimization.sentence_avg"
)
def
label_smoothed_nll_loss
(
lprobs
,
target
,
epsilon
,
ignore_index
=
None
,
reduce
=
True
):
if
target
.
dim
()
==
lprobs
.
dim
()
-
1
:
target
=
target
.
unsqueeze
(
-
1
)
nll_loss
=
-
lprobs
.
gather
(
dim
=-
1
,
index
=
target
)
smooth_loss
=
-
lprobs
.
sum
(
dim
=-
1
,
keepdim
=
True
)
if
ignore_index
is
not
None
:
pad_mask
=
target
.
eq
(
ignore_index
)
nll_loss
.
masked_fill_
(
pad_mask
,
0.0
)
smooth_loss
.
masked_fill_
(
pad_mask
,
0.0
)
else
:
nll_loss
=
nll_loss
.
squeeze
(
-
1
)
smooth_loss
=
smooth_loss
.
squeeze
(
-
1
)
if
reduce
:
nll_loss
=
nll_loss
.
sum
()
smooth_loss
=
smooth_loss
.
sum
()
eps_i
=
epsilon
/
(
lprobs
.
size
(
-
1
)
-
1
)
loss
=
(
1.0
-
epsilon
-
eps_i
)
*
nll_loss
+
eps_i
*
smooth_loss
return
loss
,
nll_loss
@
register_criterion
(
"label_smoothed_cross_entropy"
,
dataclass
=
LabelSmoothedCrossEntropyCriterionConfig
)
class
LabelSmoothedCrossEntropyCriterion
(
FairseqCriterion
):
def
__init__
(
self
,
task
,
sentence_avg
,
label_smoothing
,
ignore_prefix_size
=
0
,
report_accuracy
=
False
,
):
super
().
__init__
(
task
)
self
.
sentence_avg
=
sentence_avg
self
.
eps
=
label_smoothing
self
.
ignore_prefix_size
=
ignore_prefix_size
self
.
report_accuracy
=
report_accuracy
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output
=
model
(
**
sample
[
"net_input"
])
loss
,
nll_loss
=
self
.
compute_loss
(
model
,
net_output
,
sample
,
reduce
=
reduce
)
sample_size
=
(
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
)
logging_output
=
{
"loss"
:
loss
.
data
,
"nll_loss"
:
nll_loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"target"
].
size
(
0
),
"sample_size"
:
sample_size
,
}
if
self
.
report_accuracy
:
n_correct
,
total
=
self
.
compute_accuracy
(
model
,
net_output
,
sample
)
logging_output
[
"n_correct"
]
=
utils
.
item
(
n_correct
.
data
)
logging_output
[
"total"
]
=
utils
.
item
(
total
.
data
)
return
loss
,
sample_size
,
logging_output
def
get_lprobs_and_target
(
self
,
model
,
net_output
,
sample
):
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
)
target
=
model
.
get_targets
(
sample
,
net_output
)
if
self
.
ignore_prefix_size
>
0
:
# lprobs: B x T x C
lprobs
=
lprobs
[:,
self
.
ignore_prefix_size
:,
:].
contiguous
()
target
=
target
[:,
self
.
ignore_prefix_size
:].
contiguous
()
return
lprobs
.
view
(
-
1
,
lprobs
.
size
(
-
1
)),
target
.
view
(
-
1
)
def
compute_loss
(
self
,
model
,
net_output
,
sample
,
reduce
=
True
):
lprobs
,
target
=
self
.
get_lprobs_and_target
(
model
,
net_output
,
sample
)
loss
,
nll_loss
=
label_smoothed_nll_loss
(
lprobs
,
target
,
self
.
eps
,
ignore_index
=
self
.
padding_idx
,
reduce
=
reduce
,
)
return
loss
,
nll_loss
def
compute_accuracy
(
self
,
model
,
net_output
,
sample
):
lprobs
,
target
=
self
.
get_lprobs_and_target
(
model
,
net_output
,
sample
)
mask
=
target
.
ne
(
self
.
padding_idx
)
n_correct
=
torch
.
sum
(
lprobs
.
argmax
(
1
).
masked_select
(
mask
).
eq
(
target
.
masked_select
(
mask
))
)
total
=
torch
.
sum
(
mask
)
return
n_correct
,
total
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
)
->
None
:
"""Aggregate logging outputs from data parallel training."""
loss_sum
=
sum
(
log
.
get
(
"loss"
,
0
)
for
log
in
logging_outputs
)
nll_loss_sum
=
sum
(
log
.
get
(
"nll_loss"
,
0
)
for
log
in
logging_outputs
)
ntokens
=
sum
(
log
.
get
(
"ntokens"
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
"sample_size"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"loss"
,
loss_sum
/
sample_size
/
math
.
log
(
2
),
sample_size
,
round
=
3
)
metrics
.
log_scalar
(
"nll_loss"
,
nll_loss_sum
/
ntokens
/
math
.
log
(
2
),
ntokens
,
round
=
3
)
metrics
.
log_derived
(
"ppl"
,
lambda
meters
:
utils
.
get_perplexity
(
meters
[
"nll_loss"
].
avg
)
)
total
=
utils
.
item
(
sum
(
log
.
get
(
"total"
,
0
)
for
log
in
logging_outputs
))
if
total
>
0
:
metrics
.
log_scalar
(
"total"
,
total
)
n_correct
=
utils
.
item
(
sum
(
log
.
get
(
"n_correct"
,
0
)
for
log
in
logging_outputs
)
)
metrics
.
log_scalar
(
"n_correct"
,
n_correct
)
metrics
.
log_derived
(
"accuracy"
,
lambda
meters
:
round
(
meters
[
"n_correct"
].
sum
*
100.0
/
meters
[
"total"
].
sum
,
3
)
if
meters
[
"total"
].
sum
>
0
else
float
(
"nan"
),
)
@
staticmethod
def
logging_outputs_can_be_summed
()
->
bool
:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return
True
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
dataclasses
import
dataclass
,
field
import
torch
from
fairseq
import
metrics
,
utils
from
fairseq.criterions
import
register_criterion
from
fairseq.criterions.label_smoothed_cross_entropy
import
(
LabelSmoothedCrossEntropyCriterion
,
LabelSmoothedCrossEntropyCriterionConfig
,
)
try
:
from
simuleval.metrics.latency
import
(
AverageLagging
,
AverageProportion
,
DifferentiableAverageLagging
,
)
LATENCY_METRICS
=
{
"average_lagging"
:
AverageLagging
,
"average_proportion"
:
AverageProportion
,
"differentiable_average_lagging"
:
DifferentiableAverageLagging
,
}
except
ImportError
:
LATENCY_METRICS
=
None
@
dataclass
class
LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig
(
LabelSmoothedCrossEntropyCriterionConfig
):
latency_avg_weight
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"weight fot average latency loss."
},
)
latency_var_weight
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"weight fot variance latency loss."
},
)
latency_avg_type
:
str
=
field
(
default
=
"differentiable_average_lagging"
,
metadata
=
{
"help"
:
"latency type for average loss"
},
)
latency_var_type
:
str
=
field
(
default
=
"variance_delay"
,
metadata
=
{
"help"
:
"latency typ for variance loss"
},
)
latency_gather_method
:
str
=
field
(
default
=
"weighted_average"
,
metadata
=
{
"help"
:
"method to gather latency loss for all heads"
},
)
latency_update_after
:
int
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"Add latency loss after certain steps"
},
)
@
register_criterion
(
"latency_augmented_label_smoothed_cross_entropy"
,
dataclass
=
LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig
,
)
class
LatencyAugmentedLabelSmoothedCrossEntropyCriterion
(
LabelSmoothedCrossEntropyCriterion
):
def
__init__
(
self
,
task
,
sentence_avg
,
label_smoothing
,
ignore_prefix_size
,
report_accuracy
,
latency_avg_weight
,
latency_var_weight
,
latency_avg_type
,
latency_var_type
,
latency_gather_method
,
latency_update_after
,
):
super
().
__init__
(
task
,
sentence_avg
,
label_smoothing
,
ignore_prefix_size
,
report_accuracy
)
assert
LATENCY_METRICS
is
not
None
,
"Please make sure SimulEval is installed."
self
.
latency_avg_weight
=
latency_avg_weight
self
.
latency_var_weight
=
latency_var_weight
self
.
latency_avg_type
=
latency_avg_type
self
.
latency_var_type
=
latency_var_type
self
.
latency_gather_method
=
latency_gather_method
self
.
latency_update_after
=
latency_update_after
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
net_output
=
model
(
**
sample
[
"net_input"
])
# 1. Compute cross entropy loss
loss
,
nll_loss
=
self
.
compute_loss
(
model
,
net_output
,
sample
,
reduce
=
reduce
)
# 2. Compute cross latency loss
latency_loss
,
expected_latency
,
expected_delays_var
=
self
.
compute_latency_loss
(
model
,
sample
,
net_output
)
if
self
.
latency_update_after
>
0
:
num_updates
=
getattr
(
model
.
decoder
,
"num_updates"
,
None
)
assert
(
num_updates
is
not
None
),
"model.decoder doesn't have attribute 'num_updates'"
if
num_updates
<=
self
.
latency_update_after
:
latency_loss
=
0
loss
+=
latency_loss
sample_size
=
(
sample
[
"target"
].
size
(
0
)
if
self
.
sentence_avg
else
sample
[
"ntokens"
]
)
logging_output
=
{
"loss"
:
loss
.
data
,
"nll_loss"
:
nll_loss
.
data
,
"ntokens"
:
sample
[
"ntokens"
],
"nsentences"
:
sample
[
"target"
].
size
(
0
),
"sample_size"
:
sample_size
,
"latency"
:
expected_latency
,
"delays_var"
:
expected_delays_var
,
"latency_loss"
:
latency_loss
,
}
if
self
.
report_accuracy
:
n_correct
,
total
=
self
.
compute_accuracy
(
model
,
net_output
,
sample
)
logging_output
[
"n_correct"
]
=
utils
.
item
(
n_correct
.
data
)
logging_output
[
"total"
]
=
utils
.
item
(
total
.
data
)
return
loss
,
sample_size
,
logging_output
def
compute_latency_loss
(
self
,
model
,
sample
,
net_output
):
assert
(
net_output
[
-
1
].
encoder_padding_mask
is
None
or
not
net_output
[
-
1
].
encoder_padding_mask
[:,
0
].
any
()
),
"Only right padding on source is supported."
# 1. Obtain the expected alignment
alpha_list
=
[
item
[
"alpha"
]
for
item
in
net_output
[
1
].
attn_list
]
num_layers
=
len
(
alpha_list
)
bsz
,
num_heads
,
tgt_len
,
src_len
=
alpha_list
[
0
].
size
()
# bsz * num_layers * num_heads, tgt_len, src_len
alpha_all
=
torch
.
cat
(
alpha_list
,
dim
=
1
).
view
(
-
1
,
tgt_len
,
src_len
)
# 2 compute expected delays
# bsz * num_heads * num_layers, tgt_len, src_len for MMA
steps
=
(
torch
.
arange
(
1
,
1
+
src_len
)
.
unsqueeze
(
0
)
.
unsqueeze
(
1
)
.
expand_as
(
alpha_all
)
.
type_as
(
alpha_all
)
)
expected_delays
=
torch
.
sum
(
steps
*
alpha_all
,
dim
=-
1
)
target_padding_mask
=
(
model
.
get_targets
(
sample
,
net_output
)
.
eq
(
self
.
padding_idx
)
.
unsqueeze
(
1
)
.
expand
(
bsz
,
num_layers
*
num_heads
,
tgt_len
)
.
contiguous
()
.
view
(
-
1
,
tgt_len
)
)
src_lengths
=
(
sample
[
"net_input"
][
"src_lengths"
]
.
unsqueeze
(
1
)
.
expand
(
bsz
,
num_layers
*
num_heads
)
.
contiguous
()
.
view
(
-
1
)
)
expected_latency
=
LATENCY_METRICS
[
self
.
latency_avg_type
](
expected_delays
,
src_lengths
,
None
,
target_padding_mask
=
target_padding_mask
)
# 2.1 average expected latency of heads
# bsz, num_layers * num_heads
expected_latency
=
expected_latency
.
view
(
bsz
,
-
1
)
if
self
.
latency_gather_method
==
"average"
:
# bsz * tgt_len
expected_latency
=
expected_delays
.
mean
(
dim
=
1
)
elif
self
.
latency_gather_method
==
"weighted_average"
:
weights
=
torch
.
nn
.
functional
.
softmax
(
expected_latency
,
dim
=
1
)
expected_latency
=
torch
.
sum
(
expected_latency
*
weights
,
dim
=
1
)
elif
self
.
latency_gather_method
==
"max"
:
expected_latency
=
expected_latency
.
max
(
dim
=
1
)[
0
]
else
:
raise
NotImplementedError
expected_latency
=
expected_latency
.
sum
()
avg_loss
=
self
.
latency_avg_weight
*
expected_latency
# 2.2 variance of expected delays
expected_delays_var
=
(
expected_delays
.
view
(
bsz
,
-
1
,
tgt_len
).
var
(
dim
=
1
).
mean
(
dim
=
1
)
)
expected_delays_var
=
expected_delays_var
.
sum
()
var_loss
=
self
.
latency_avg_weight
*
expected_delays_var
# 3. Final loss
latency_loss
=
avg_loss
+
var_loss
return
latency_loss
,
expected_latency
,
expected_delays_var
@
classmethod
def
reduce_metrics
(
cls
,
logging_outputs
)
->
None
:
super
().
reduce_metrics
(
logging_outputs
)
latency
=
sum
(
log
.
get
(
"latency"
,
0
)
for
log
in
logging_outputs
)
delays_var
=
sum
(
log
.
get
(
"delays_var"
,
0
)
for
log
in
logging_outputs
)
latency_loss
=
sum
(
log
.
get
(
"latency_loss"
,
0
)
for
log
in
logging_outputs
)
nsentences
=
sum
(
log
.
get
(
"nsentences"
,
0
)
for
log
in
logging_outputs
)
metrics
.
log_scalar
(
"latency"
,
latency
.
float
()
/
nsentences
,
nsentences
,
round
=
3
)
metrics
.
log_scalar
(
"delays_var"
,
delays_var
/
nsentences
,
nsentences
,
round
=
3
)
metrics
.
log_scalar
(
"latency_loss"
,
latency_loss
/
nsentences
,
nsentences
,
round
=
3
)
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import metrics, utils
from fairseq.criterions import register_criterion

from .label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)

from dataclasses import dataclass, field


@dataclass
class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    alignment_lambda: float = field(
        default=0.05, metadata={"help": "weight for the alignment loss"}
    )


@register_criterion(
    "label_smoothed_cross_entropy_with_alignment",
    dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig,
)
class LabelSmoothedCrossEntropyCriterionWithAlignment(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
        super().__init__(task, sentence_avg, label_smoothing)
        self.alignment_lambda = alignment_lambda

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        alignment_loss = None

        # Compute alignment loss only for training set and non dummy batches.
        if "alignments" in sample and sample["alignments"] is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)

        if alignment_loss is not None:
            logging_output["alignment_loss"] = utils.item(alignment_loss.data)
            loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output

    def compute_alignment_loss(self, sample, net_output):
        attn_prob = net_output[1]["attn"][0]
        bsz, tgt_sz, src_sz = attn_prob.shape
        attn = attn_prob.view(bsz * tgt_sz, src_sz)

        align = sample["alignments"]
        align_weights = sample["align_weights"].float()

        if len(align) > 0:
            # Alignment loss computation. align (shape [:, 2]) contains the src-tgt
            # index pairs corresponding to the alignments. align_weights (shape [:])
            # contains the 1 / frequency of a tgt index for normalizing.
            loss = -(
                (attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
                * align_weights[:, None]
            ).sum()
        else:
            return None

        return loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss_sum = utils.item(
            sum(log.get("nll_loss", 0) for log in logging_outputs)
        )
        alignment_loss_sum = utils.item(
            sum(log.get("alignment_loss", 0) for log in logging_outputs)
        )
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_scalar(
            "alignment_loss",
            alignment_loss_sum / sample_size / math.log(2),
            sample_size,
            round=3,
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
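compute_alignment_loss gathers one attention probability per supervised src-tgt pair and sums their weighted negative log-likelihoods. A minimal sketch of that gather on made-up tensors (shapes, indices, and weights are illustrative only):

import torch

# Flattened attention matrix of shape (bsz * tgt_len, src_len); rows sum to 1.
attn = torch.softmax(torch.rand(6, 4), dim=-1)
# Two hypothetical alignment pairs; columns are [src_idx, tgt_idx].
align = torch.tensor([[0, 1], [2, 4]])
align_weights = torch.tensor([1.0, 0.5])

# Same gather as compute_alignment_loss: row = tgt index, column = src index.
picked = attn[align[:, 1][:, None], align[:, 0][:, None]]
loss = -(picked.log() * align_weights[:, None]).sum()
print(loss.item())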
PyTorch/NLP/new-Transformer/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.data.data_utils import lengths_to_mask


@dataclass
class LabelSmoothedCrossEntropyWithCtcCriterionConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    ctc_weight: float = field(default=1.0, metadata={"help": "weight for CTC loss"})


@register_criterion(
    "label_smoothed_cross_entropy_with_ctc",
    dataclass=LabelSmoothedCrossEntropyWithCtcCriterionConfig,
)
class LabelSmoothedCrossEntropyWithCtcCriterion(LabelSmoothedCrossEntropyCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        ctc_weight,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        self.ctc_weight = ctc_weight

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        ctc_loss = torch.tensor(0.0).type_as(loss)
        if self.ctc_weight > 0.0:
            ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample)
            ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample)
            ctc_tgt_mask = lengths_to_mask(ctc_tgt_lens)
            ctc_tgt_flat = ctc_tgt.masked_select(ctc_tgt_mask)
            reduction = "sum" if reduce else "none"
            ctc_loss = (
                F.ctc_loss(
                    ctc_lprobs,
                    ctc_tgt_flat,
                    ctc_lens,
                    ctc_tgt_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )
        loss += ctc_loss

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),
            "nll_loss": utils.item(nll_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar(
            "ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
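The CTC term relies on F.ctc_loss with (T, N, C) log-probabilities, concatenated targets, and per-sample lengths. A minimal sketch with made-up shapes, mirroring the reduction and zero_infinity arguments used above:

import torch
import torch.nn.functional as F

# Illustrative shapes: T frames, N sentences, C vocabulary entries (0 = blank).
T, N, C = 12, 2, 7
ctc_lprobs = F.log_softmax(torch.randn(T, N, C), dim=-1)
ctc_lens = torch.full((N,), T, dtype=torch.long)          # per-sample input lengths
ctc_tgt_lens = torch.tensor([4, 3])                        # per-sample target lengths
ctc_tgt_flat = torch.randint(1, C, (int(ctc_tgt_lens.sum()),))  # concatenated targets

loss = F.ctc_loss(
    ctc_lprobs,
    ctc_tgt_flat,
    ctc_lens,
    ctc_tgt_lens,
    reduction="sum",
    zero_infinity=True,
)
print(loss.item())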
PyTorch/NLP/new-Transformer/fairseq/criterions/legacy_masked_lm.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """
    Function to compute the cross entropy loss. The default value of
    ignore_index is the same as the default value for F.cross_entropy in
    pytorch.
    """
    assert logits.size(0) == targets.size(
        -1
    ), "Logits and Targets tensor shapes don't match up"

    loss = F.nll_loss(
        F.log_softmax(logits, -1, dtype=torch.float32),
        targets,
        reduction="sum",
        ignore_index=ignore_index,
    )
    return loss


@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    This optionally also computes the next sentence prediction (NSP) loss and
    adds it to the overall loss based on the specified args. There are three
    cases to consider:
        1) Generic MLM training without NSP loss. In this case sentence_targets
           and sentence_logits are both None.
        2) BERT training without NSP loss. In this case sentence_targets is
           not None but sentence_logits is None and we should not be computing
           a sentence level loss.
        3) BERT training with NSP loss. In this case both sentence_targets and
           sentence_logits are not None and we should be computing a sentence
           level loss. The weight of the sentence level loss is specified as
           an argument.
    """

    def __init__(self, task, masked_lm_only, nsp_loss_weight):
        super().__init__(task)
        self.masked_lm_only = masked_lm_only
        self.nsp_loss_weight = nsp_loss_weight

    @staticmethod
    def add_args(parser):
        """Args for MaskedLM Loss"""
        # Default for masked_lm_only is False so as to not break BERT training
        parser.add_argument(
            "--masked-lm-only",
            default=False,
            action="store_true",
            help="compute MLM loss only",
        )
        parser.add_argument(
            "--nsp-loss-weight",
            default=1.0,
            type=float,
            help="weight for next sentence prediction" " loss (default 1)",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N, T, C) to (N*T, C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample["lm_target"].view(-1)
        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)

        # compute the number of tokens for which loss is computed. This is used
        # to normalize the loss
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample["nsentences"]
        # nsentences = 0

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.masked_lm_only:
            sentence_logits = output_metadata["sentence_logits"]
            sentence_targets = sample["sentence_target"].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Check for logits being none which can happen when remove_heads
            # is set to true in the BERT model. Ideally we should set
            # masked_lm_only to true in this case, but that requires some
            # refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets
                )
                loss += self.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            "sentence_loss": (
                (utils.item(sentence_loss.data) if reduce else sentence_loss.data)
                if sentence_loss is not None
                else 0.0
            ),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
        sentence_loss_sum = sum(
            log.get("sentence_loss", 0) for log in logging_outputs
        )
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_loss = sum(log.get("loss", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss",
            agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
            sample_size,
            round=3,
        )
        metrics.log_scalar(
            "lm_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )
        metrics.log_scalar(
            "sentence_loss",
            sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
            nsentences,
            round=3,
        )
        metrics.log_scalar(
            "nll_loss",
            lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
            ntokens,
            round=3,
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
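A hypothetical call of the compute_cross_entropy_loss helper defined above, assuming that function is in scope; the values and padding index are made up, and the final division mirrors the per-token normalization done in forward():

import torch

# 5 token positions over a vocabulary of 10; index 1 acts as padding here.
logits = torch.randn(5, 10)
targets = torch.tensor([4, 1, 7, 2, 9])
loss = compute_cross_entropy_loss(logits, targets, ignore_index=1)
ntokens = targets.ne(1).sum()  # normalize by the number of non-pad tokens
print((loss / ntokens).item())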
PyTorch/NLP/new-Transformer/fairseq/criterions/masked_lm.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import math
from omegaconf import II

import torch
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MaskedLmConfig(FairseqDataclass):
    tpu: bool = II("common.tpu")


@register_criterion("masked_lm", dataclass=MaskedLmConfig)
class MaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    """

    def __init__(self, cfg: MaskedLmConfig, task):
        super().__init__(task)
        self.tpu = cfg.tpu

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        masked_tokens = sample["target"].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum()

        # Rare: when all tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_tokens = None  # always project all tokens on TPU
        elif masked_tokens.device == torch.device("cpu"):
            if not masked_tokens.any():
                masked_tokens = None
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )

        logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]

        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx,
        )

        logging_output = {
            "loss": loss if self.tpu else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
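The torch.where branch above keeps the boolean mask non-empty without a device-to-host sync when nothing happens to be masked. A minimal sketch of that fallback on a made-up mask:

import torch

# Pathological case: no position is masked at all.
masked_tokens = torch.zeros(3, 4, dtype=torch.bool)
masked_tokens = torch.where(
    masked_tokens.any(),      # 0-dim condition, evaluated on-device
    masked_tokens,            # keep the mask when at least one entry is True
    masked_tokens.new([True]),  # otherwise select everything (project all tokens)
)
print(masked_tokens.sum().item())  # all 12 positions selected in this case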
PyTorch/NLP/new-Transformer/fairseq/criterions/model_criterion.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Dict, List

import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


logger = logging.getLogger(__name__)


@dataclass
class ModelCriterionConfig(FairseqDataclass):
    loss_weights: Dict[str, float] = field(
        default_factory=dict,
        metadata={"help": "weights for the loss terms"},
    )
    log_keys: List[str] = field(
        default_factory=list,
        metadata={"help": "additional output keys to log"},
    )


@register_criterion("model", dataclass=ModelCriterionConfig)
class ModelCriterion(FairseqCriterion):
    """
    This criterion relies on the model to supply losses.
    The losses should be a dictionary of name -> scalar returned by
    the model either by including it in the net_output dict or by
    implementing a get_losses(net_output, sample) method. The final loss is
    a scaled sum of all losses according to weights in loss_weights.
    If no weights are provided, then all losses are scaled by 1.0.

    The losses will be automatically logged. Additional keys from
    net_output dict can be logged via the log_keys parameter.
    """

    def __init__(self, task, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.loss_weights = loss_weights
        self.log_keys = log_keys

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])

        scaled_losses = {}

        if hasattr(model, "get_losses"):
            losses = model.get_losses(net_output, sample)
        elif isinstance(net_output, dict) and "losses" in net_output:
            losses = net_output["losses"]
        else:
            raise Exception("Could not retrieve losses")

        for lk, p in losses.items():
            try:
                coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk]
            except KeyError:
                logger.error(
                    f"weight for loss {lk} is not in loss_weights ({self.loss_weights})"
                )
                raise
            if coef != 0 and p is not None:
                scaled_losses[lk] = coef * p.float()

        loss = sum(scaled_losses.values())

        if "sample_size" in net_output:
            sample_size = net_output["sample_size"]
        else:
            sample_size = loss.numel()

        if reduce and loss.numel() > 1:
            loss = loss.sum()

        logging_output = {
            "loss": loss.data,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            "_world_size": 1,
        }

        for lk in self.log_keys:
            if lk in net_output and net_output[lk] is not None:
                if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1:
                    logging_output[lk] = float(net_output[lk])
                else:
                    for i, v in enumerate(net_output[lk]):
                        logging_output[f"{lk}_{i}"] = float(v)

        if len(scaled_losses) > 1:
            for lk, l in scaled_losses.items():
                if l.numel() > 1:
                    l = l.sum()
                logging_output[f"loss_{lk}"] = l.item()

        if "logs" in net_output:
            for lgw in net_output["logs"]:
                logging_output[lgw] = net_output["logs"][lgw]

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)

        builtin_keys = {
            "loss",
            "ntokens",
            "nsentences",
            "sample_size",
            "_world_size",
        }

        world_size = utils.item(
            sum(log.get("_world_size", 0) for log in logging_outputs)
        )

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs)
                if k.startswith("loss_"):
                    metrics.log_scalar(k, val / sample_size, sample_size, round=3)
                else:
                    metrics.log_scalar(k, val / world_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
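A hypothetical net_output in the form ModelCriterion consumes: a dict of named scalar losses scaled by loss_weights. The sketch below uses made-up losses and weights, and for brevity falls back to a weight of 1.0 instead of raising the KeyError the criterion uses for missing weights:

import torch

net_output = {"losses": {"recon": torch.tensor(2.0), "reg": torch.tensor(0.5)}}
loss_weights = {"recon": 1.0, "reg": 0.1}

scaled = {
    name: loss_weights.get(name, 1.0) * value.float()
    for name, value in net_output["losses"].items()
    if value is not None
}
total = sum(scaled.values())
print(total.item())  # 1.0*2.0 + 0.1*0.5 = 2.05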
PyTorch/NLP/new-Transformer/fairseq/criterions/nat_loss.py
0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from torch import Tensor

from dataclasses import dataclass, field


@dataclass
class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass):
    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )


@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig)
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
    def __init__(self, task, label_smoothing):
        super().__init__(task)
        self.label_smoothing = label_smoothing

    def _compute_loss(
        self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
    ):
        """
        outputs: batch x len x d_model
        targets: batch x len
        masks:   batch x len

        policy_logprob: if there is some policy
            depends on the likelihood score as rewards.
        """

        def mean_ds(x: Tensor, dim=None) -> Tensor:
            return (
                x.float().mean().type_as(x)
                if dim is None
                else x.float().mean(dim).type_as(x)
            )

        if masks is not None:
            outputs, targets = outputs[masks], targets[masks]

        if masks is not None and not masks.any():
            nll_loss = torch.tensor(0)
            loss = nll_loss
        else:
            logits = F.log_softmax(outputs, dim=-1)
            if targets.dim() == 1:
                losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
            else:  # soft-labels
                losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
                losses = losses.sum(-1)

            nll_loss = mean_ds(losses)
            if label_smoothing > 0:
                loss = (
                    nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
                )
            else:
                loss = nll_loss

        loss = loss * factor
        return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}

    def _custom_loss(self, loss, name="loss", factor=1.0):
        return {"name": name, "loss": loss, "factor": factor}

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        nsentences, ntokens = sample["nsentences"], sample["ntokens"]

        # B x T
        src_tokens, src_lengths = (
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"],
        )
        tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]

        outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
        losses, nll_loss = [], []

        for obj in outputs:
            if outputs[obj].get("loss", None) is None:
                _losses = self._compute_loss(
                    outputs[obj].get("out"),
                    outputs[obj].get("tgt"),
                    outputs[obj].get("mask", None),
                    outputs[obj].get("ls", 0.0),
                    name=obj + "-loss",
                    factor=outputs[obj].get("factor", 1.0),
                )
            else:
                _losses = self._custom_loss(
                    outputs[obj].get("loss"),
                    name=obj + "-loss",
                    factor=outputs[obj].get("factor", 1.0),
                )

            losses += [_losses]
            if outputs[obj].get("nll_loss", False):
                nll_loss += [_losses.get("nll_loss", 0.0)]

        loss = sum(l["loss"] for l in losses)
        nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)

        # NOTE:
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        for l in losses:
            logging_output[l["name"]] = (
                utils.item(l["loss"].data / l["factor"])
                if reduce
                else l["loss"].data / l["factor"]
            )

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )
        loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))

        metrics.log_scalar(
            "loss", loss / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
        )

        for key in logging_outputs[0]:
            if key[-5:] == "-loss":
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(
                    key[:-5],
                    val / sample_size / math.log(2) if sample_size > 0 else 0.0,
                    sample_size,
                    round=3,
                )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
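The soft-label branch of _compute_loss uses a per-position KL divergence between the model's log-probabilities and a target distribution. A minimal sketch on made-up tensors:

import torch
import torch.nn.functional as F

# 3 positions over a vocabulary of 6; targets are soft labels (rows sum to 1).
outputs = torch.randn(3, 6)
targets = torch.softmax(torch.randn(3, 6), dim=-1)

logits = F.log_softmax(outputs, dim=-1)
losses = F.kl_div(logits, targets, reduction="none").sum(-1)  # per-position KL
nll_loss = losses.float().mean()                              # mean_ds equivalent
print(nll_loss.item())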