Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
97d7fcb9
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b7b1a30bc49cad350c7a642e1171e886d83cd909"
Commit
97d7fcb9
authored
Oct 31, 2017
by
Myle Ott
Browse files
Left pad source and right pad target
parent
7ae79c12
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
8 deletions
+15
-8
fairseq/data.py
fairseq/data.py
+2
-2
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+2
-6
fairseq/utils.py
fairseq/utils.py
+11
-0
No files found.
fairseq/data.py
View file @
97d7fcb9
...
@@ -142,8 +142,8 @@ def skip_group_enumerator(it, ngpus, offset=0):
...
@@ -142,8 +142,8 @@ def skip_group_enumerator(it, ngpus, offset=0):
class
LanguagePairDataset
(
object
):
class
LanguagePairDataset
(
object
):
# padding constants
# padding constants
LEFT_PAD_SOURCE
=
Fals
e
LEFT_PAD_SOURCE
=
Tru
e
LEFT_PAD_TARGET
=
Tru
e
LEFT_PAD_TARGET
=
Fals
e
def
__init__
(
self
,
src
,
dst
,
pad_idx
,
eos_idx
):
def
__init__
(
self
,
src
,
dst
,
pad_idx
,
eos_idx
):
self
.
src
=
src
self
.
src
=
src
...
...
fairseq/sequence_generator.py
View file @
97d7fcb9
...
@@ -61,10 +61,6 @@ class SequenceGenerator(object):
...
@@ -61,10 +61,6 @@ class SequenceGenerator(object):
cuda_device: GPU on which to do generation.
cuda_device: GPU on which to do generation.
timer: StopwatchMeter for timing generations.
timer: StopwatchMeter for timing generations.
"""
"""
def
lstrip_pad
(
tensor
):
return
tensor
[
tensor
.
eq
(
self
.
pad
).
sum
():]
if
maxlen_b
is
None
:
if
maxlen_b
is
None
:
maxlen_b
=
self
.
maxlen
maxlen_b
=
self
.
maxlen
...
@@ -80,8 +76,8 @@ class SequenceGenerator(object):
...
@@ -80,8 +76,8 @@ class SequenceGenerator(object):
timer
.
stop
(
s
[
'ntokens'
])
timer
.
stop
(
s
[
'ntokens'
])
for
i
,
id
in
enumerate
(
s
[
'id'
]):
for
i
,
id
in
enumerate
(
s
[
'id'
]):
src
=
input
[
'src_tokens'
].
data
[
i
,
:]
src
=
input
[
'src_tokens'
].
data
[
i
,
:]
# remove padding from ref, which appears at the beginning
# remove padding from ref
ref
=
l
strip_pad
(
s
[
'target'
].
data
[
i
,
:])
ref
=
utils
.
r
strip_pad
(
s
[
'target'
].
data
[
i
,
:]
,
self
.
pad
)
yield
id
,
src
,
ref
,
hypos
[
i
]
yield
id
,
src
,
ref
,
hypos
[
i
]
def
generate
(
self
,
src_tokens
,
beam_size
=
None
,
maxlen
=
None
):
def
generate
(
self
,
src_tokens
,
beam_size
=
None
,
maxlen
=
None
):
...
...
fairseq/utils.py
View file @
97d7fcb9
...
@@ -202,3 +202,14 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
...
@@ -202,3 +202,14 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
# Note that the dictionary can be modified inside the method.
# Note that the dictionary can be modified inside the method.
hypo_tokens
=
tokenizer
.
Tokenizer
.
tokenize
(
hypo_str
,
dst_dict
,
add_if_not_exist
=
True
)
hypo_tokens
=
tokenizer
.
Tokenizer
.
tokenize
(
hypo_str
,
dst_dict
,
add_if_not_exist
=
True
)
return
hypo_tokens
,
hypo_str
,
alignment
return
hypo_tokens
,
hypo_str
,
alignment
def lstrip_pad(tensor, pad):
    """Return `tensor` with its leading run of `pad` entries removed.

    Only the contiguous padding at the start is dropped; any `pad` value that
    occurs after the first non-pad entry is kept. A tensor consisting solely
    of padding (or an empty tensor) yields an empty tensor.

    Args:
        tensor: tensor of token indices; expected to be 1-D (callers pass a
            single row of a batch) -- TODO confirm no other ranks are used.
        pad: index of the padding symbol.

    Returns:
        A slice of `tensor` beginning at its first non-pad entry.
    """
    keep = tensor.ne(pad).nonzero()
    if keep.numel() == 0:
        # Nothing but padding: slice past the end to get an empty tensor.
        return tensor[tensor.numel():]
    # Slice from the first non-pad position instead of counting every pad in
    # the tensor; a total count would also cut real tokens whenever `pad`
    # shows up again later in the sequence.
    first = int(keep[0][0])
    return tensor[first:]
def rstrip_pad(tensor, pad):
    """Return `tensor` with its trailing run of `pad` entries removed.

    Only the contiguous padding at the end is dropped; any `pad` value that
    occurs before the last non-pad entry is kept. A tensor consisting solely
    of padding (or an empty tensor) yields an empty tensor.

    Args:
        tensor: tensor of token indices; expected to be 1-D (callers pass a
            single row of a batch) -- TODO confirm no other ranks are used.
        pad: index of the padding symbol.

    Returns:
        A slice of `tensor` ending just after its last non-pad entry.
    """
    keep = tensor.ne(pad).nonzero()
    if keep.numel() == 0:
        # Nothing but padding: an explicit empty slice avoids the `[:-0]`
        # pitfall, which would also be empty but for the wrong reason.
        return tensor[:0]
    # Cut after the last non-pad position instead of counting every pad in
    # the tensor; a total count would also cut real tokens whenever `pad`
    # appears earlier in the sequence.
    end = int(keep[-1][0]) + 1
    return tensor[:end]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment