OpenDAS / Fairseq / Commits

Unverified commit cbaf59d4, authored Mar 05, 2018 by Sergey Edunov, committed by GitHub on Mar 05, 2018

Merge pull request #116 from facebookresearch/oss-merge-internal

Oss merge internal

Parents: 56f9ec3c, b03b53b4
Showing 8 changed files with 194 additions and 123 deletions (+194, -123):

fairseq/criterions/label_smoothed_cross_entropy.py   +10  -40
fairseq/data.py                                       +40  -30
fairseq/sequence_generator.py                          +2   -2
fairseq/sequence_scorer.py                             +1   -1
generate.py                                           +11   -5
preprocess.py                                         +32  -19
tests/test_label_smoothing.py                         +84  -17
tests/utils.py                                        +14   -9
fairseq/criterions/label_smoothed_cross_entropy.py
@@ -7,7 +7,6 @@
 import math

 import torch
-from torch.autograd import Variable
 import torch.nn.functional as F

 from fairseq import utils

@@ -15,41 +14,6 @@ from fairseq import utils
 from . import FairseqCriterion, register_criterion


-class LabelSmoothedNLLLoss(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
-        grad_input = input.new(input.size()).zero_()
-        target = target.view(target.size(0), 1)
-        grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
-
-        norm = grad_input.size(-1)
-        if weights is not None:
-            if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
-                weights = Variable(weights, requires_grad=False)
-            norm = weights.sum()
-            grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
-
-        if padding_idx is not None:
-            norm -= 1 if weights is None else weights[padding_idx]
-            grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
-
-        grad_input = grad_input.add(-eps / norm)
-
-        ctx.grad_input = grad_input
-        if reduce:
-            return input.new([grad_input.view(-1).dot(input.view(-1))])
-        else:
-            return grad_input * input
-
-    @staticmethod
-    def backward(ctx, grad):
-        grad_input = ctx.grad_input
-        if not isinstance(grad_input, torch.autograd.Variable):
-            grad_input = utils.volatile_variable(grad_input)
-        return grad_input * grad, None, None, None, None, None
-
-
 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

@@ -73,10 +37,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """
         net_output = model(**sample['net_input'])
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         lprobs = lprobs.view(-1, lprobs.size(-1))
-        target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
-        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
+        target = sample['target'].unsqueeze(-1)
+        non_pad_mask = target.ne(self.padding_idx)
+        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
+        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
+        if reduce:
+            nll_loss = nll_loss.sum()
+            smooth_loss = smooth_loss.sum()
+        eps_i = self.eps / lprobs.size(-1)
+        loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
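Editor's note: the hunk above replaces the custom autograd Function with an inline, differentiable computation of the label-smoothed loss. A minimal, self-contained sketch of that same math on plain PyTorch tensors may help; it is written against a current PyTorch API, and the random logits, the eps value, and the padding index are illustrative, not taken from the repository:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps, padding_idx = 0.1, 1                                # illustrative values
lprobs = F.log_softmax(torch.randn(6, 5), dim=-1)        # (batch*time, vocab) log-probs
target = torch.tensor([2, 3, 4, padding_idx, 2, 3]).unsqueeze(-1)

# same computation as the new criterion body above
non_pad_mask = target.ne(padding_idx)
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]     # -log p(true token)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]     # uniform-prior term
eps_i = eps / lprobs.size(-1)
loss = (1. - eps) * nll_loss.sum() + eps_i * smooth_loss.sum()
print(loss.item())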
fairseq/data.py
@@ -57,17 +57,17 @@ def load_dataset(path, load_splits, src=None, dst=None):
     dataset = LanguageDatasets(src, dst, src_dict, dst_dict)

     # Load dataset from binary files
-    def all_splits_exist(src, dst):
+    def all_splits_exist(src, dst, lang):
         for split in load_splits:
-            filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
+            filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
             if not os.path.exists(os.path.join(path, filename)):
                 return False
         return True

     # infer langcode
-    if all_splits_exist(src, dst):
+    if all_splits_exist(src, dst, src):
         langcode = '{}-{}'.format(src, dst)
-    elif all_splits_exist(dst, src):
+    elif all_splits_exist(dst, src, src):
         langcode = '{}-{}'.format(dst, src)
     else:
         raise Exception('Dataset cannot be loaded from path: ' + path)
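Editor's note: the old pattern reused placeholder {1}, so both language sides resolved to the source-language index file; the new explicit {3} placeholder lets the caller pick the language. A quick illustration of the corrected pattern (the split name and language codes here are made up):

split, src, dst = 'train', 'de', 'en'   # hypothetical split and language pair
for lang in (src, dst):
    print('{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang))
# train.de-en.de.idx
# train.de-en.en.idx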
@@ -84,9 +84,13 @@ def load_dataset(path, load_splits, src=None, dst=None):
         if not IndexedInMemoryDataset.exists(src_path):
             break
+        target_dataset = None
+        if IndexedInMemoryDataset.exists(dst_path):
+            target_dataset = IndexedInMemoryDataset(dst_path)
         dataset.splits[prefix] = LanguagePairDataset(
             IndexedInMemoryDataset(src_path),
-            IndexedInMemoryDataset(dst_path),
+            target_dataset,
             pad_idx=dataset.src_dict.pad(),
             eos_idx=dataset.src_dict.eos(),
         )
@@ -194,21 +198,20 @@ class LanguagePairDataset(torch.utils.data.Dataset):
     def __getitem__(self, i):
         # subtract 1 for 0-based indexing
         source = self.src[i].long() - 1
-        target = self.dst[i].long() - 1
-        return {
-            'id': i,
-            'source': source,
-            'target': target,
-        }
+        res = {'id': i, 'source': source}
+        if self.dst:
+            res['target'] = self.dst[i].long() - 1
+        return res

     def __len__(self):
         return len(self.src)

     def collater(self, samples):
-        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
+        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)

     @staticmethod
-    def collate(samples, pad_idx, eos_idx):
+    def collate(samples, pad_idx, eos_idx, has_target=True):
         if len(samples) == 0:
             return {}
@@ -220,26 +223,31 @@ class LanguagePairDataset(torch.utils.data.Dataset):
         id = torch.LongTensor([s['id'] for s in samples])
         src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
-        target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
-        # we create a shifted version of targets for feeding the
-        # previous output token(s) into the next decoder step
-        prev_output_tokens = merge(
-            'target',
-            left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-            move_eos_to_beginning=True,
-        )

         # sort by descending source length
         src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
         src_lengths, sort_order = src_lengths.sort(descending=True)
         id = id.index_select(0, sort_order)
         src_tokens = src_tokens.index_select(0, sort_order)
-        prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
-        target = target.index_select(0, sort_order)
+
+        prev_output_tokens = None
+        target = None
+        ntokens = None
+        if has_target:
+            target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
+            # we create a shifted version of targets for feeding the
+            # previous output token(s) into the next decoder step
+            prev_output_tokens = merge(
+                'target',
+                left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
+                move_eos_to_beginning=True,
+            )
+            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
+            target = target.index_select(0, sort_order)
+            ntokens = sum(len(s['target']) for s in samples)

         return {
             'id': id,
-            'ntokens': sum(len(s['target']) for s in samples),
+            'ntokens': ntokens,
             'net_input': {
                 'src_tokens': src_tokens,
                 'src_lengths': src_lengths,
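Editor's note: the hunk above keeps the sort-then-select pattern and moves all target-side work behind has_target. A tiny standalone sketch (toy lengths and ids, not repository data) of how the same sort_order is reused to permute every per-sample tensor consistently:

import torch

src_lengths = torch.LongTensor([3, 7, 5])
ids = torch.LongTensor([10, 11, 12])
# sort by descending source length, then apply the same permutation everywhere
src_lengths, sort_order = src_lengths.sort(descending=True)
ids = ids.index_select(0, sort_order)
print(src_lengths.tolist(), ids.tolist())   # [7, 5, 3] [11, 12, 10]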
@@ -301,21 +309,23 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
     sample_len = 0
     ignored = []
     for idx in map(int, indices):
-        if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
+        src_size = src.sizes[idx]
+        dst_size = dst.sizes[idx] if dst else src_size
+        if not _valid_size(src_size, dst_size, max_positions):
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
             raise Exception((
                 "Sample #{} has size (src={}, dst={}) but max size is {}."
                 " Skip this example with --skip-invalid-size-inputs-valid-test"
-            ).format(idx, src.sizes[idx], dst.sizes[idx], max_positions))
+            ).format(idx, src_size, dst_size, max_positions))

-        sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
+        sample_len = max(sample_len, src_size, dst_size)
         num_tokens = (len(batch) + 1) * sample_len
         if yield_batch(idx, num_tokens):
             yield batch
             batch = []
-            sample_len = max(src.sizes[idx], dst.sizes[idx])
+            sample_len = max(src_size, dst_size)

         batch.append(idx)
@@ -332,7 +342,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
                     descending=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
     if max_tokens is None:
         max_tokens = float('Inf')
     if max_sentences is None:
fairseq/sequence_generator.py
@@ -77,11 +77,11 @@ class SequenceGenerator(object):
                 prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
             )
             if timer is not None:
-                timer.stop(s['ntokens'])
+                timer.stop(sum([len(h[0]['tokens']) for h in hypos]))
             for i, id in enumerate(s['id'].data):
                 src = input['src_tokens'].data[i, :]
                 # remove padding from ref
-                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
+                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]

     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
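Editor's note: the timer now counts generated hypothesis tokens instead of reference tokens, and the reference is only stripped of padding when a target exists. For readers unfamiliar with the helper used here, a rough illustrative reimplementation of what utils.strip_pad is expected to do (this is a sketch, not the fairseq function itself):

import torch

def strip_pad_sketch(tensor, pad):
    # drop padding symbols from a 1-D reference before yielding it
    return tensor[tensor.ne(pad)]

ref = torch.LongTensor([4, 5, 6, 2, 1, 1])    # assume 1 is the pad index
print(strip_pad_sketch(ref, pad=1).tolist())  # [4, 5, 6, 2]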
fairseq/sequence_scorer.py
@@ -46,7 +46,7 @@ class SequenceScorer(object):
                     'alignment': alignment,
                     'positional_scores': pos_scores_i,
                 }]
-                # return results in the same format as SequenceGenenerator
+                # return results in the same format as SequenceGenerator
                 yield id, src, ref, hypos

     def score(self, sample):
generate.py
@@ -84,6 +84,7 @@ def main(args):
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     num_sentences = 0
+    has_target = True
     with progress_bar.build_progress_bar(args, itr) as t:
         if args.score_reference:
             translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
@@ -94,18 +95,22 @@ def main(args):
         wps_meter = TimeMeter()
         for sample_id, src_tokens, target_tokens, hypos in translations:
             # Process input and ground truth
-            target_tokens = target_tokens.int().cpu()
+            has_target = target_tokens is not None
+            target_tokens = target_tokens.int().cpu() if has_target else None

             # Either retrieve the original sentences or regenerate them from tokens.
             if align_dict is not None:
                 src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
                 target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
             else:
                 src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
-                target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
+                target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True) if has_target else ''

             if not args.quiet:
                 print('S-{}\t{}'.format(sample_id, src_str))
-                print('T-{}\t{}'.format(sample_id, target_str))
+                if has_target:
+                    print('T-{}\t{}'.format(sample_id, target_str))

             # Process top predictions
             for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
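Editor's note: a tiny illustration of the S-/T- console lines printed above; only the format strings come from the diff, the sentence text and id are made up.

sample_id, src_str, target_str = 42, 'ein Beispiel', 'an example'   # hypothetical values
print('S-{}\t{}'.format(sample_id, src_str))
print('T-{}\t{}'.format(sample_id, target_str))
# S-42<TAB>ein Beispiel
# T-42<TAB>an example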
@@ -133,7 +138,7 @@ def main(args):
                 ))

                 # Score only the top hypothesis
-                if i == 0:
+                if has_target and i == 0:
                     if align_dict is not None or args.remove_bpe is not None:
                         # Convert back to tokens for evaluation with unk replacement and/or without BPE
                         target_tokens = tokenizer.Tokenizer.tokenize(
@@ -146,7 +151,8 @@ def main(args):
     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
         num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
-    print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
+    if has_target:
+        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


 if __name__ == '__main__':
preprocess.py
@@ -21,9 +21,9 @@ def get_parser():
         description='Data pre-processing: Create dictionary and store data in binary format')
     parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
     parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
-    parser.add_argument('--trainpref', metavar='FP', default='train', help='target language')
-    parser.add_argument('--validpref', metavar='FP', default='valid', help='comma separated, valid language prefixes')
-    parser.add_argument('--testpref', metavar='FP', default='test', help='comma separated, test language prefixes')
+    parser.add_argument('--trainpref', metavar='FP', default=None, help='target language')
+    parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid language prefixes')
+    parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test language prefixes')
     parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
     parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
                         help='map words appearing less than threshold times to unknown')
@@ -37,12 +37,14 @@ def get_parser():
     parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
                         help='output format (optional)')
     parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
+    parser.add_argument('--only-source', action='store_true', help='Only process the source language')

     return parser


 def main(args):
     print(args)
     os.makedirs(args.destdir, exist_ok=True)
+    target = not args.only_source

     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
@@ -60,16 +62,20 @@ def main(args):
         if args.srcdict:
             src_dict = dictionary.Dictionary.load(args.srcdict)
         else:
+            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
             src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
-        if args.tgtdict:
-            tgt_dict = dictionary.Dictionary.load(args.tgtdict)
-        else:
-            tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
+        if target:
+            if args.tgtdict:
+                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
+            else:
+                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
+                tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))

     src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
                   threshold=args.thresholdsrc, nwords=args.nwordssrc)
-    tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
-                  threshold=args.thresholdtgt, nwords=args.nwordstgt)
+    if target:
+        tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
+                      threshold=args.thresholdtgt, nwords=args.nwordstgt)

     def make_binary_dataset(input_prefix, output_prefix, lang):
         dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
@@ -100,19 +106,26 @@ def main(args):
             output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
             shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)

-    make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
-    make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
-    for k, validpref in enumerate(args.validpref.split(',')):
-        outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
-        make_dataset(validpref, outprefix, args.source_lang, args.output_format)
-        make_dataset(validpref, outprefix, args.target_lang, args.output_format)
-    for k, testpref in enumerate(args.testpref.split(',')):
-        outprefix = 'test{}'.format(k) if k > 0 else 'test'
-        make_dataset(testpref, outprefix, args.source_lang, args.output_format)
-        make_dataset(testpref, outprefix, args.target_lang, args.output_format)
+    def make_all(args, make_dataset, lang):
+        if args.trainpref:
+            make_dataset(args.trainpref, 'train', lang, args.output_format)
+        if args.validpref:
+            for k, validpref in enumerate(args.validpref.split(',')):
+                outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
+                make_dataset(validpref, outprefix, lang, args.output_format)
+        if args.testpref:
+            for k, testpref in enumerate(args.testpref.split(',')):
+                outprefix = 'test{}'.format(k) if k > 0 else 'test'
+                make_dataset(testpref, outprefix, lang, args.output_format)
+
+    make_all(args, make_dataset, args.source_lang)
+    if target:
+        make_all(args, make_dataset, args.target_lang)

     print('| Wrote preprocessed data to {}'.format(args.destdir))

     if args.alignfile:
+        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
         src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
         tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
         src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
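Editor's note: the outprefix naming kept inside the new make_all helper maps comma-separated validation/test prefixes onto numbered output splits. A small sketch of that naming rule (the prefix values are hypothetical):

validpref = 'valid.tok,valid-extra.tok'   # hypothetical --validpref value
for k, pref in enumerate(validpref.split(',')):
    outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
    print(pref, '->', outprefix)
# valid.tok -> valid
# valid-extra.tok -> valid1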
tests/test_label_smoothing.py
@@ -4,31 +4,98 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 #

-import torch
+import argparse
+import copy
 import unittest

-from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
-from torch.autograd import Variable, gradcheck
+import torch
+from torch.autograd import Variable

+from fairseq import utils
+from fairseq.criterions.cross_entropy import CrossEntropyCriterion
+from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion

-torch.set_default_tensor_type('torch.DoubleTensor')
+import tests.utils as test_utils


 class TestLabelSmoothing(unittest.TestCase):

-    def test_label_smoothing(self):
-        input = Variable(torch.randn(3, 5), requires_grad=True)
-        idx = torch.rand(3) * 4
-        target = Variable(idx.long())
-        criterion = LabelSmoothedNLLLoss()
-        self.assertTrue(gradcheck(
-            lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
-        ))
-        weights = torch.ones(5)
-        weights[2] = 0
-        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
-        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))
+    def setUp(self):
+        # build dictionary
+        self.d = test_utils.dummy_dictionary(3)
+        vocab = len(self.d)
+        self.assertEqual(vocab, 4 + 3)  # 4 special + 3 tokens
+        self.assertEqual(self.d.pad(), 1)
+        self.assertEqual(self.d.eos(), 2)
+        self.assertEqual(self.d.unk(), 3)
+        pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6
+
+        # build dataset
+        self.data = [
+            # the first batch item has padding
+            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
+            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
+        ]
+        self.sample = next(test_utils.dummy_dataloader(self.data))
+
+        # build model
+        self.args = argparse.Namespace()
+        self.args.sentence_avg = False
+        self.args.probs = torch.FloatTensor([
+            #      pad   eos  unk   w1   w2   w3
+            [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
+            [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
+            [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
+        ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
+        self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
+
+    def test_nll_loss(self):
+        self.args.label_smoothing = 0.1
+        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+        self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
+        self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)
+
+    def test_padding(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        loss, _, logging_output = crit(self.model, self.sample)
+
+        def get_one_no_padding(idx):
+            # create a new sample with just a single batch item so that there's
+            # no padding
+            sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
+            args1 = copy.copy(self.args)
+            args1.probs = args1.probs[idx, :, :].unsqueeze(0)
+            model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
+            loss1, _, _ = crit(model1, sample1)
+            return loss1
+
+        loss1 = get_one_no_padding(0)
+        loss2 = get_one_no_padding(1)
+        self.assertAlmostEqual(loss, loss1 + loss2)
+
+    def test_reduction(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        loss, _, logging_output = crit(self.model, self.sample, reduce=True)
+        unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
+        self.assertAlmostEqual(loss, unreduced_loss.sum())
+
+    def test_zero_eps(self):
+        self.args.label_smoothing = 0.0
+        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+        self.assertAlmostEqual(nll_loss, smooth_loss)
+
+    def assertAlmostEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertLess((t1 - t2).abs().max(), 1e-6)


 if __name__ == '__main__':
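Editor's note: test_zero_eps relies on the identity that the smoothed loss degenerates to plain NLL when the smoothing weight is zero. A quick numeric check of that identity using the same formula as the new criterion (the tensors here are arbitrary illustrative values, not the test fixtures):

import torch
import torch.nn.functional as F

lprobs = F.log_softmax(torch.randn(4, 7), dim=-1)
target = torch.tensor([2, 4, 5, 6]).unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target).sum()
smooth_loss = -lprobs.sum()
eps = 0.0
loss = (1. - eps) * nll_loss + (eps / lprobs.size(-1)) * smooth_loss
assert torch.allclose(loss, nll_loss)   # with eps = 0 the smoothed loss is exactly the NLL loss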
tests/utils.py
@@ -92,7 +92,7 @@ class TestEncoder(FairseqEncoder):
 class TestIncrementalDecoder(FairseqIncrementalDecoder):
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
-        assert hasattr(args, 'beam_probs')
+        assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args
@@ -116,14 +116,19 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         steps = list(range(tgt_len))

         # define output in terms of raw probs
-        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
-        for i, step in enumerate(steps):
-            # args.beam_probs gives the probability for every vocab element,
-            # starting with eos, then unknown, and then the rest of the vocab
-            if step < len(self.args.beam_probs):
-                probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
-            else:
-                probs[:, i, self.dictionary.eos()] = 1.0
+        if hasattr(self.args, 'probs'):
+            assert self.args.probs.dim() == 3, \
+                'expected probs to have size bsz*steps*vocab'
+            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
+        else:
+            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
+            for i, step in enumerate(steps):
+                # args.beam_probs gives the probability for every vocab element,
+                # starting with eos, then unknown, and then the rest of the vocab
+                if step < len(self.args.beam_probs):
+                    probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
+                else:
+                    probs[:, i, self.dictionary.eos()] = 1.0

         # random attention
         attn = torch.rand(bbsz, src_len, tgt_len)
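Editor's note: the new args.probs path lets tests specify an exact per-step output distribution of shape (bsz, steps, vocab); the decoder then just slices out the rows for the requested steps. A minimal sketch of that slicing (shapes are illustrative):

import torch

probs = torch.rand(2, 3, 7)                        # (bsz, steps, vocab), as the assert requires
steps = [0, 2]                                     # decoding steps requested in this forward pass
selected = probs.index_select(1, torch.LongTensor(steps))
print(tuple(selected.shape))                       # (2, 2, 7)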