Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
97d7fcb9
Commit
97d7fcb9
authored
Oct 31, 2017
by
Myle Ott
Browse files
Left pad source and right pad target
parent
7ae79c12
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
8 deletions
+15
-8
fairseq/data.py
fairseq/data.py
+2
-2
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+2
-6
fairseq/utils.py
fairseq/utils.py
+11
-0
No files found.
fairseq/data.py
View file @
97d7fcb9
...
@@ -142,8 +142,8 @@ def skip_group_enumerator(it, ngpus, offset=0):
...
@@ -142,8 +142,8 @@ def skip_group_enumerator(it, ngpus, offset=0):
class
LanguagePairDataset
(
object
):
class
LanguagePairDataset
(
object
):
# padding constants
# padding constants
LEFT_PAD_SOURCE
=
Fals
e
LEFT_PAD_SOURCE
=
Tru
e
LEFT_PAD_TARGET
=
Tru
e
LEFT_PAD_TARGET
=
Fals
e
def
__init__
(
self
,
src
,
dst
,
pad_idx
,
eos_idx
):
def
__init__
(
self
,
src
,
dst
,
pad_idx
,
eos_idx
):
self
.
src
=
src
self
.
src
=
src
...
...
fairseq/sequence_generator.py
View file @
97d7fcb9
...
@@ -61,10 +61,6 @@ class SequenceGenerator(object):
...
@@ -61,10 +61,6 @@ class SequenceGenerator(object):
cuda_device: GPU on which to do generation.
cuda_device: GPU on which to do generation.
timer: StopwatchMeter for timing generations.
timer: StopwatchMeter for timing generations.
"""
"""
def
lstrip_pad
(
tensor
):
return
tensor
[
tensor
.
eq
(
self
.
pad
).
sum
():]
if
maxlen_b
is
None
:
if
maxlen_b
is
None
:
maxlen_b
=
self
.
maxlen
maxlen_b
=
self
.
maxlen
...
@@ -80,8 +76,8 @@ class SequenceGenerator(object):
...
@@ -80,8 +76,8 @@ class SequenceGenerator(object):
timer
.
stop
(
s
[
'ntokens'
])
timer
.
stop
(
s
[
'ntokens'
])
for
i
,
id
in
enumerate
(
s
[
'id'
]):
for
i
,
id
in
enumerate
(
s
[
'id'
]):
src
=
input
[
'src_tokens'
].
data
[
i
,
:]
src
=
input
[
'src_tokens'
].
data
[
i
,
:]
# remove padding from ref
, which appears at the beginning
# remove padding from ref
ref
=
l
strip_pad
(
s
[
'target'
].
data
[
i
,
:])
ref
=
utils
.
r
strip_pad
(
s
[
'target'
].
data
[
i
,
:]
,
self
.
pad
)
yield
id
,
src
,
ref
,
hypos
[
i
]
yield
id
,
src
,
ref
,
hypos
[
i
]
def
generate
(
self
,
src_tokens
,
beam_size
=
None
,
maxlen
=
None
):
def
generate
(
self
,
src_tokens
,
beam_size
=
None
,
maxlen
=
None
):
...
...
fairseq/utils.py
View file @
97d7fcb9
...
@@ -202,3 +202,14 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
...
@@ -202,3 +202,14 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
# Note that the dictionary can be modified inside the method.
# Note that the dictionary can be modified inside the method.
hypo_tokens
=
tokenizer
.
Tokenizer
.
tokenize
(
hypo_str
,
dst_dict
,
add_if_not_exist
=
True
)
hypo_tokens
=
tokenizer
.
Tokenizer
.
tokenize
(
hypo_str
,
dst_dict
,
add_if_not_exist
=
True
)
return
hypo_tokens
,
hypo_str
,
alignment
return
hypo_tokens
,
hypo_str
,
alignment
def
lstrip_pad
(
tensor
,
pad
):
return
tensor
[
tensor
.
eq
(
pad
).
sum
():]
def
rstrip_pad
(
tensor
,
pad
):
strip
=
tensor
.
eq
(
pad
).
sum
()
if
strip
>
0
:
return
tensor
[:
-
strip
]
return
tensor
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment