Commit 97d7fcb9 authored by Myle Ott's avatar Myle Ott

Left pad source and right pad target

parent 7ae79c12
@@ -142,8 +142,8 @@ def skip_group_enumerator(it, ngpus, offset=0):
 class LanguagePairDataset(object):
     # padding constants
-    LEFT_PAD_SOURCE = False
-    LEFT_PAD_TARGET = True
+    LEFT_PAD_SOURCE = True
+    LEFT_PAD_TARGET = False

     def __init__(self, src, dst, pad_idx, eos_idx):
         self.src = src
...
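A note on what the flipped constants mean in practice: with LEFT_PAD_SOURCE = True, a collated source batch carries its pad symbols before the tokens, while LEFT_PAD_TARGET = False keeps target padding at the end. The sketch below illustrates the two layouts; collate_tokens here is a hypothetical helper written for this note, not necessarily the repository's own collater.

import torch

def collate_tokens(values, pad_idx, left_pad):
    # pack 1-D token tensors into a [batch, max_len] matrix,
    # filling the unused positions with pad_idx on the chosen side
    max_len = max(v.size(0) for v in values)
    out = values[0].new_full((len(values), max_len), pad_idx)
    for i, v in enumerate(values):
        if left_pad:
            out[i, max_len - v.size(0):] = v   # pads on the left (source)
        else:
            out[i, :v.size(0)] = v             # pads on the right (target)
    return out

batch = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
print(collate_tokens(batch, pad_idx=1, left_pad=True))    # [[4, 5, 6], [1, 7, 8]]
print(collate_tokens(batch, pad_idx=1, left_pad=False))   # [[4, 5, 6], [7, 8, 1]]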
@@ -61,10 +61,6 @@ class SequenceGenerator(object):
             cuda_device: GPU on which to do generation.
             timer: StopwatchMeter for timing generations.
         """
-        def lstrip_pad(tensor):
-            return tensor[tensor.eq(self.pad).sum():]
-
         if maxlen_b is None:
             maxlen_b = self.maxlen
@@ -80,8 +76,8 @@ class SequenceGenerator(object):
                 timer.stop(s['ntokens'])
             for i, id in enumerate(s['id']):
                 src = input['src_tokens'].data[i, :]
-                # remove padding from ref, which appears at the beginning
-                ref = lstrip_pad(s['target'].data[i, :])
+                # remove padding from ref
+                ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]

     def generate(self, src_tokens, beam_size=None, maxlen=None):
...
@@ -202,3 +202,14 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict):
     # Note that the dictionary can be modified inside the method.
     hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
     return hypo_tokens, hypo_str, alignment
+
+
+def lstrip_pad(tensor, pad):
+    return tensor[tensor.eq(pad).sum():]
+
+
+def rstrip_pad(tensor, pad):
+    strip = tensor.eq(pad).sum()
+    if strip > 0:
+        return tensor[:-strip]
+    return tensor
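Both helpers assume the padding is contiguous on one side of the tensor: they count every occurrence of pad and strip that many entries from the front (lstrip_pad) or the back (rstrip_pad). The strip > 0 guard in rstrip_pad matters because tensor[:-0] is an empty slice rather than the whole tensor. This pairs with the SequenceGenerator change above: targets are now right-padded, so the reference is cleaned with rstrip_pad instead of the old nested lstrip_pad. A quick usage sketch, assuming the two definitions above are in scope and using an arbitrary pad index of 1:

import torch

pad = 1
src = torch.tensor([pad, pad, 4, 5, 6])   # left-padded source row
tgt = torch.tensor([7, 8, 9, pad, pad])   # right-padded target row

print(lstrip_pad(src, pad))       # tensor([4, 5, 6])
print(rstrip_pad(tgt, pad))       # tensor([7, 8, 9])
print(rstrip_pad(tgt[:3], pad))   # no trailing pads: returned unchanged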