OpenDAS / Fairseq / Commits

Commit eb005cdb, authored Jan 01, 2018 by Myle Ott

Streamline data formatting utils

parent 6f6cb4ab
Showing 5 changed files with 38 additions and 26 deletions (+38 -26)
fairseq/data.py                      +8   -6
fairseq/multiprocessing_trainer.py   +1   -1
fairseq/sequence_generator.py        +3   -3
fairseq/utils.py                     +25  -15
train.py                             +1   -1
fairseq/data.py

```diff
@@ -229,13 +229,15 @@ class LanguagePairDataset(object):
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'ntokens': sum(len(s['target']) for s in samples),
             'net_input': {
                 'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-                # we create a shifted version of targets for feeding the previous
-                # output token(s) into the next decoder step
+                # we create a shifted version of targets for feeding the
+                # previous output token(s) into the next decoder step
                 'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                                       move_eos_to_beginning=True),
             },
             'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
+            'ntokens': sum(len(s['target']) for s in samples),
         }
 
     @staticmethod
```
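The comment describes the standard teacher-forcing shift: the decoder input is the target with its EOS token rotated to the front. A minimal sketch (not part of this commit; the token IDs and eos index are invented) of what `merge('target', ..., move_eos_to_beginning=True)` produces for a single sequence:

```python
# Hedged illustration of the shifted decoder input described in the
# collate comment above; token IDs and the eos index are invented.
import torch

eos = 2
target = torch.LongTensor([4, 5, 6, eos])              # e.g. "w x y </s>"

# move_eos_to_beginning=True rotates the target right by one position, so
# the decoder receives the previous output token at every step:
input_tokens = torch.cat([target[-1:], target[:-1]])   # tensor([2, 4, 5, 6])

assert input_tokens[0].item() == eos
```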
fairseq/multiprocessing_trainer.py

```diff
@@ -381,4 +381,4 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self._max_bsz_seen = sample['target'].size(0)
             torch.cuda.empty_cache()
-        self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
+        self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
```
fairseq/sequence_generator.py

```diff
@@ -65,7 +65,7 @@ class SequenceGenerator(object):
             maxlen_b = self.maxlen
 
         for sample in data_itr:
-            s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
+            s = utils.make_variable(sample, volatile=True, cuda_device=cuda_device)
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
@@ -74,10 +74,10 @@ class SequenceGenerator(object):
                 maxlen=int(maxlen_a*srclen + maxlen_b))
             if timer is not None:
                 timer.stop(s['ntokens'])
-            for i, id in enumerate(s['id']):
+            for i, id in enumerate(s['id'].data):
                 src = input['src_tokens'].data[i, :]
                 # remove padding from ref
-                ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
+                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]
 
     def generate(self, src_tokens, beam_size=None, maxlen=None):
```
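Both call-site edits track the utils.py rewrite below: the recursive `make_variable` now wraps every tensor in the sample, including `'id'`, in a Variable (the old `prepare_sample` passed `'id'` through untouched), hence the extra `.data`; and references are stripped of padding on both sides with the new `strip_pad`. A small sketch of the unwrapping, using the PyTorch 0.3-era autograd API:

```python
# Sketch (PyTorch 0.3-era autograd API) of why the loop now iterates
# s['id'].data: make_variable wraps the 'id' tensor in a Variable, which
# the old prepare_sample left as a plain tensor.
import torch
from torch.autograd import Variable

ids = Variable(torch.LongTensor([7, 3, 9]))
for i, id in enumerate(ids.data):   # iterate the underlying tensor
    print(i, int(id))
```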
fairseq/utils.py

```diff
@@ -176,23 +176,25 @@ def _upgrade_args(args):
     return args
 
 
-def prepare_sample(sample, volatile=False, cuda_device=None):
+def make_variable(sample, volatile=False, cuda_device=None):
     """Wrap input tensors in Variable class."""
 
-    def make_variable(tensor):
-        if cuda_device is not None and torch.cuda.is_available():
-            tensor = tensor.cuda(async=True, device=cuda_device)
-        return Variable(tensor, volatile=volatile)
-
-    return {
-        'id': sample['id'],
-        'ntokens': sample['ntokens'],
-        'target': make_variable(sample['target']),
-        'net_input': {
-            key: make_variable(sample[key])
-            for key in ['src_tokens', 'input_tokens']
-        },
-    }
+    def _make_variable(maybe_tensor):
+        if torch.is_tensor(maybe_tensor):
+            if cuda_device is not None and torch.cuda.is_available():
+                maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
+            return Variable(maybe_tensor, volatile=volatile)
+        elif isinstance(maybe_tensor, dict):
+            return {
+                key: _make_variable(value)
+                for key, value in maybe_tensor.items()
+            }
+        elif isinstance(maybe_tensor, list):
+            return [_make_variable(x) for x in maybe_tensor]
+        else:
+            return maybe_tensor
+
+    return _make_variable(sample)
 
 
 def load_align_dict(replace_unk):
```
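The rewrite replaces the hard-coded key list (`'src_tokens'`, `'input_tokens'`) with a recursion over dicts and lists, so any nested sample layout works. A hedged usage sketch (PyTorch 0.3-era Variable/volatile semantics; the sample layout mirrors the LanguagePairDataset collate output above):

```python
# Usage sketch of the new recursive make_variable; not from the commit.
import torch
from fairseq import utils

sample = {
    'id': torch.LongTensor([0, 1]),
    'ntokens': 7,                    # non-tensor values pass through unchanged
    'net_input': {
        'src_tokens': torch.LongTensor([[4, 5, 2], [1, 6, 2]]),
        'input_tokens': torch.LongTensor([[2, 4, 5], [2, 1, 6]]),
    },
    'target': torch.LongTensor([[4, 5, 2], [6, 2, 1]]),
}

s = utils.make_variable(sample, volatile=True)  # every tensor is now a Variable
```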
```diff
@@ -247,6 +249,14 @@ def rstrip_pad(tensor, pad):
     return tensor
 
 
+def strip_pad(tensor, pad):
+    if tensor[0] == pad:
+        tensor = lstrip_pad(tensor, pad)
+    if tensor[-1] == pad:
+        tensor = rstrip_pad(tensor, pad)
+    return tensor
+
+
 def maybe_no_grad(condition):
     if hasattr(torch, 'no_grad') and condition:
         return torch.no_grad()
```
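The new `strip_pad` composes the existing one-sided helpers, since a target may be padded on either side depending on `LEFT_PAD_TARGET`. A sketch of the behavior; `lstrip_pad`/`rstrip_pad` live elsewhere in fairseq/utils.py, so hypothetical stand-ins are inlined here:

```python
# Sketch of strip_pad's effect; the two helpers below are stand-ins, not
# the fairseq implementations.
import torch

def lstrip_pad(tensor, pad):   # stand-in: drop leading pad tokens
    i = 0
    while i < tensor.size(0) and tensor[i] == pad:
        i += 1
    return tensor[i:]

def rstrip_pad(tensor, pad):   # stand-in: drop trailing pad tokens
    j = tensor.size(0)
    while j > 0 and tensor[j - 1] == pad:
        j -= 1
    return tensor[:j]

pad = 1
ref = torch.LongTensor([pad, pad, 4, 5, 2, pad])
# rstrip_pad alone leaves the left padding produced under LEFT_PAD_TARGET:
print(rstrip_pad(ref, pad))                   # tensor([1, 1, 4, 5, 2])
print(lstrip_pad(rstrip_pad(ref, pad), pad))  # tensor([4, 5, 2])
```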
train.py

```diff
@@ -159,7 +159,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
         ntokens = sum(s['ntokens'] for s in sample)
-        nsentences = sum(s['src_tokens'].size(0) for s in sample)
+        nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
         loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
         bsz_meter.update(nsentences)
         wpb_meter.update(ntokens)
```
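The one-line fix reads `src_tokens` through the `'net_input'` sub-dict, matching the nested sample layout that the collate function in fairseq/data.py builds. A minimal sketch with invented shapes:

```python
# Sketch of the corrected batch-size accounting; shapes are invented.
import torch

sample = [
    {'ntokens': 7, 'net_input': {'src_tokens': torch.zeros(2, 5).long()}},
    {'ntokens': 9, 'net_input': {'src_tokens': torch.zeros(3, 4).long()}},
]
ntokens = sum(s['ntokens'] for s in sample)                              # 16
nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)  # 5
```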