Commit f8377a70, authored Sep 30, 2018 by myleott
fbshipit-source-id: 6a835d32f9dc5e0de118f1b46d365d0e0cc85e11
Parent: 864b89d0
Showing 5 changed files with 236 additions and 15 deletions:
    fairseq/data/noising.py           +126   -0
    fairseq/optim/fp16_optimizer.py     +3   -1
    fairseq/sequence_generator.py       +4   -8
    fairseq/trainer.py                  +9   -6
    tests/test_noising.py              +94   -0
fairseq/data/noising.py  (new file, 0 → 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import numpy as np


class WordNoising(object):
    """Generate a noisy version of a sentence, without changing words themselves."""

    def __init__(self, dictionary, bpe_cont_marker="@@"):
        self.dictionary = dictionary
        # bpe_end[i] is True if token i ends a word, i.e. does not carry the
        # BPE continuation marker
        self.bpe_end = np.array([
            not self.dictionary[i].endswith(bpe_cont_marker)
            for i in range(len(self.dictionary))
        ])

    def noising(self, x, lengths, noising_prob=0.0):
        raise NotImplementedError()

    def _get_bpe_word_idx(self, x):
        # x: (T x B)
        bpe_end = self.bpe_end[x]
        # do a reverse cumulative sum to generate word ids: all BPE tokens of
        # the same word share one id
        word_idx = bpe_end[::-1].cumsum(0)[::-1]
        word_idx = word_idx.max(0)[None, :] - word_idx
        return word_idx


class WordDropout(WordNoising):
    """Randomly drop input words. If blank_idx is not passed (default is
    None), dropped words are removed entirely; otherwise they are replaced
    by blank_idx."""

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def noising(self, x, lengths, dropout_prob=0.1, blank_idx=None):
        # x: (T x B), lengths: B
        if dropout_prob == 0:
            return x, lengths

        assert 0 < dropout_prob < 1

        # be sure to drop entire words
        word_idx = self._get_bpe_word_idx(x)

        sentences = []
        modified_lengths = []
        for i in range(lengths.size(0)):
            # Since dropout probabilities need to apply over non-pad tokens,
            # it is not trivial to generate the keep mask without considering
            # input lengths; otherwise, this could be done outside the loop
            keep = np.random.rand(lengths[i] - 1) >= dropout_prob
            # ith example: [x0, x1, ..., eos, pad, ..., pad]
            assert x[lengths[i] - 1, i] == self.dictionary.eos()
            words = x[:lengths[i], i].tolist()

            # TODO: speed up the following loop
            # drop words from the input according to keep
            new_s = [
                w if keep[word_idx[j, i]] else blank_idx
                for j, w in enumerate(words)
            ]
            new_s = [w for w in new_s if w is not None]
            # we need to have at least one word in the sentence (more than the
            # start / end sentence symbols)
            if len(new_s) == 1:
                new_s.append(words[np.random.randint(0, len(words))])
            assert (
                len(new_s) >= 2
                and new_s[-1] == self.dictionary.eos()
            ), "New sentence is invalid."
            sentences.append(new_s)
            modified_lengths.append(len(new_s))

        # re-construct input
        modified_lengths = torch.LongTensor(modified_lengths)
        modified_x = torch.LongTensor(
            modified_lengths.max(),
            modified_lengths.size(0)
        ).fill_(self.dictionary.pad())
        for i in range(modified_lengths.size(0)):
            modified_x[:modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))

        return modified_x, modified_lengths


class WordShuffle(WordNoising):
    """Shuffle words by no more than k positions."""

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def noising(self, x, lengths, max_shuffle_distance=3):
        # x: (T x B), lengths: B
        if max_shuffle_distance == 0:
            return x, lengths

        # max_shuffle_distance < 1 will return the same sequence
        assert max_shuffle_distance > 1

        # define noise word scores
        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0) - 1, x.size(1)),
        )
        noise[0] = -1  # do not move start sentence symbol
        # be sure to shuffle entire words
        word_idx = self._get_bpe_word_idx(x)
        x2 = x.clone()
        for i in range(lengths.size(0)):
            # generate a random permutation
            scores = word_idx[:lengths[i] - 1, i] + noise[word_idx[:lengths[i] - 1, i], i]
            # ensure no reordering inside a word
            scores += 1e-6 * np.arange(lengths[i] - 1)
            permutation = scores.argsort()
            # shuffle words
            x2[:lengths[i] - 1, i].copy_(
                x2[:lengths[i] - 1, i][torch.from_numpy(permutation)]
            )
        return x2, lengths
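
For readers skimming the new module, here is a minimal usage sketch (not part of the commit) that runs WordShuffle on a toy batch. It assumes fairseq.data.Dictionary and the batch layout used above: column-major (T x B) with a trailing eos per sentence.

import torch
from fairseq.data import Dictionary, noising

vocab = Dictionary()
for w in ["he@@", "llo", "world"]:
    vocab.add_symbol(w)

# one sentence, column-major (T x B) = (4 x 1): tokens followed by eos
tokens = ["he@@", "llo", "world"]
x = torch.LongTensor([[vocab.index(w)] for w in tokens] + [[vocab.eos()]])
lengths = torch.LongTensor([4])

shuffler = noising.WordShuffle(vocab)
x_noised, lengths_noised = shuffler.noising(x, lengths, max_shuffle_distance=3)
# "he@@ llo" can only move as a unit, and eos stays in the final position
print(x_noised.t(), lengths_noised)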
fairseq/optim/fp16_optimizer.py

...
@@ -133,7 +133,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
         self.scaler.update_scale(overflow)
         if overflow:
             if self.scaler.loss_scale <= self.args.min_loss_scale:
-                raise Exception((
+                # Use FloatingPointError as an uncommon error that parent
+                # functions can safely catch to stop training.
+                raise FloatingPointError((
                     'Minimum loss scale reached ({}). Your loss is probably exploding. '
                     'Try lowering the learning rate, using gradient clipping or '
                     'increasing the batch size.'
...
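
The motivation for the narrower exception type can be sketched as follows (illustrative only; fake_train_step and its arguments are hypothetical, not fairseq API): a caller can catch FloatingPointError to stop training gracefully, while unrelated bugs still surface as ordinary exceptions.

def fake_train_step(loss_scale, min_loss_scale=1e-4):
    # stand-in for the optimizer's overflow check; names are hypothetical
    if loss_scale <= min_loss_scale:
        raise FloatingPointError(
            'Minimum loss scale reached ({}). Your loss is probably exploding.'
            .format(loss_scale)
        )

try:
    fake_train_step(loss_scale=1e-5)
except FloatingPointError as e:
    # a parent training loop can stop cleanly here
    print('stopping training:', e)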
fairseq/sequence_generator.py

...
@@ -480,21 +480,17 @@ class SequenceGenerator(object):
         if len(self.models) == 1:
             return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)

-        avg_probs = None
+        log_probs = []
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
-            if avg_probs is None:
-                avg_probs = probs
-            else:
-                avg_probs.add_(probs)
+            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=True)
+            log_probs.append(probs)
             if attn is not None:
                 if avg_attn is None:
                     avg_attn = attn
                 else:
                     avg_attn.add_(attn)
-        avg_probs.div_(len(self.models))
-        avg_probs.log_()
+        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(self.models))
         if avg_attn is not None:
             avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
...
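
The replacement computes the ensemble average in log space. A quick self-contained check (not part of the commit) of the identity it relies on, log(mean_i p_i) = logsumexp_i(log p_i) - log(N); the log-space form avoids exponentiating and re-logging very small probabilities:

import math
import torch

# two models' probability distributions over two tokens, stored as log-probs
log_probs = torch.log(torch.tensor([[0.1, 0.9], [0.3, 0.7]]))
n_models = log_probs.size(0)

old_way = torch.log(torch.exp(log_probs).mean(dim=0))            # mean probs, then log
new_way = torch.logsumexp(log_probs, dim=0) - math.log(n_models)  # stays in log space

print(torch.allclose(old_way, new_way))  # True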
fairseq/trainer.py

...
@@ -45,7 +45,15 @@ class Trainer(object):
         else:
             self._model = model.cuda()

         # initialize meters
+        self._dummy_batch = dummy_batch
+        self._num_updates = 0
+        self._optim_history = None
+        self._optimizer = None
+        self._wrapped_model = None
+        self.init_meters(args)
+
+    def init_meters(self, args):
         self.meters = OrderedDict()
         self.meters['train_loss'] = AverageMeter()
         self.meters['train_nll_loss'] = AverageMeter()
...
@@ -63,11 +71,6 @@ class Trainer(object):
         self.meters['wall'] = TimeMeter()       # wall time in seconds
         self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

-        self._dummy_batch = dummy_batch
-        self._num_updates = 0
-        self._optim_history = None
-        self._optimizer = None
-        self._wrapped_model = None

     @property
     def model(self):
...
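
A minimal sketch (illustrative, not fairseq code) of the pattern this refactor enables: with meter construction isolated in init_meters(), callers can reset every meter with one call instead of re-creating the Trainer.

from collections import OrderedDict

class AverageMeter:
    def __init__(self):
        self.sum, self.count = 0.0, 0

class MiniTrainer:
    def __init__(self):
        self.init_meters()

    def init_meters(self):
        # re-callable: rebuilds all meters in one place
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()

t = MiniTrainer()
t.init_meters()  # e.g., to reset statistics after loading a checkpoint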
tests/test_noising.py  (new file, 0 → 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import unittest

from fairseq.data import data_utils, Dictionary, noising


class TestDataNoising(unittest.TestCase):
    def _get_test_data(self):
        vocab = Dictionary()
        vocab.add_symbol("he@@")
        vocab.add_symbol("llo")
        vocab.add_symbol("how")
        vocab.add_symbol("are")
        vocab.add_symbol("y@@")
        vocab.add_symbol("ou")
        vocab.add_symbol("n@@")
        vocab.add_symbol("ew")
        vocab.add_symbol("or@@")
        vocab.add_symbol("k")

        src_tokens = [
            ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
            ["how", "are", "y@@", "ou"],
        ]
        src_len = [len(x) for x in src_tokens]
        # build a (B x T) batch padded to max length + 1 for the trailing eos
        x = torch.LongTensor(len(src_tokens), max(src_len) + 1).fill_(vocab.pad())
        for i in range(len(src_tokens)):
            for j in range(len(src_tokens[i])):
                x[i][j] = vocab.index(src_tokens[i][j])
            x[i][j + 1] = vocab.eos()
        x = x.transpose(1, 0)  # convert to (T x B)
        return vocab, x, torch.LongTensor([i + 1 for i in src_len])

    def test_word_dropout(self):
        vocab, x, x_len = self._get_test_data()
        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            # Expect only the first word (2 bpe tokens) of the first example
            # to be dropped out
            self.assertEqual(x_len[0] - 2, l_noised[0])
            for i in range(l_noised[0]):
                self.assertEqual(x_noised[i][0], x[i + 2][0])

    def test_word_blank(self):
        vocab, x, x_len = self._get_test_data()
        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            # Expect only the first word (2 bpe tokens) of the first example
            # to be blanked out
            self.assertEqual(x_len[0], l_noised[0])
            for i in range(l_noised[0]):
                if i < 2:
                    self.assertEqual(x_noised[i][0], vocab.unk())
                else:
                    self.assertEqual(x_noised[i][0], x[i][0])

    def test_word_shuffle(self):
        vocab, x, x_len = self._get_test_data()
        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(vocab)

            # max_shuffle_distance = 0 should be a no-op
            x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
            for i in range(len(x_len)):
                for j in range(x_len[i]):
                    self.assertEqual(x[j][i], x_noised[j][i])
            self.assertEqual(x_len[0], l_noised[0])

            x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
            # Expect the first example to be unchanged, and the second example
            # to have its last three tokens shuffled:
            # 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
            for i in range(x_len[0]):
                self.assertEqual(x[i][0], x_noised[i][0])
            shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}
            for k, v in shuffle_map.items():
                self.assertEqual(x[k][1], x_noised[v][1])
            self.assertEqual(x_len[0], l_noised[0])
            self.assertEqual(x_len[1], l_noised[1])


if __name__ == '__main__':
    unittest.main()
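
The tests rely on data_utils.numpy_seed for reproducible random draws. A hand-rolled equivalent (an assumption about its behavior, shown for illustration only): it seeds NumPy inside the with-block and restores the previous RNG state on exit, so surrounding code is unaffected.

import contextlib
import numpy as np

@contextlib.contextmanager
def numpy_seed(seed):
    state = np.random.get_state()   # save the current RNG state
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)  # restore it so other code is unaffected

with numpy_seed(1234):
    a = np.random.rand(3)
with numpy_seed(1234):
    b = np.random.rand(3)
assert (a == b).all()  # same seed, same draws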