Commit eb005cdb authored by Myle Ott

Streamline data formatting utils

parent 6f6cb4ab
@@ -229,13 +229,15 @@ class LanguagePairDataset(object):
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-            # we create a shifted version of targets for feeding the previous
-            # output token(s) into the next decoder step
-            'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-                                  move_eos_to_beginning=True),
-            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
             'ntokens': sum(len(s['target']) for s in samples),
+            'net_input': {
+                'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
+                # we create a shifted version of targets for feeding the
+                # previous output token(s) into the next decoder step
+                'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
+                                      move_eos_to_beginning=True),
+            },
+            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
         }

     @staticmethod
...
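With this change, the tensors the model consumes directly are grouped under 'net_input', while 'id', 'ntokens', and 'target' stay at the top level for bookkeeping and the loss. A minimal sketch of how downstream code might unpack a batch in the new layout; the model and criterion objects below are placeholders, not part of this commit:

    # `batch` is assumed to be the dict built by the collating code above
    net_output = model(**batch['net_input'])       # src_tokens + shifted input_tokens
    loss = criterion(net_output, batch['target'])  # target is kept outside net_input
    tokens_processed = batch['ntokens']            # top-level bookkeeping field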
@@ -381,4 +381,4 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self._max_bsz_seen = sample['target'].size(0)
             torch.cuda.empty_cache()
-        self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
+        self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
...
@@ -65,7 +65,7 @@ class SequenceGenerator(object):
             maxlen_b = self.maxlen
         for sample in data_itr:
-            s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
+            s = utils.make_variable(sample, volatile=True, cuda_device=cuda_device)
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
...
@@ -74,10 +74,10 @@ class SequenceGenerator(object):
                                           maxlen=int(maxlen_a*srclen + maxlen_b))
             if timer is not None:
                 timer.stop(s['ntokens'])
-            for i, id in enumerate(s['id']):
+            for i, id in enumerate(s['id'].data):
                 src = input['src_tokens'].data[i, :]
                 # remove padding from ref
-                ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
+                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]

     def generate(self, src_tokens, beam_size=None, maxlen=None):
...
@@ -176,23 +176,25 @@ def _upgrade_args(args):
     return args


-def prepare_sample(sample, volatile=False, cuda_device=None):
+def make_variable(sample, volatile=False, cuda_device=None):
     """Wrap input tensors in Variable class."""

-    def make_variable(tensor):
-        if cuda_device is not None and torch.cuda.is_available():
-            tensor = tensor.cuda(async=True, device=cuda_device)
-        return Variable(tensor, volatile=volatile)
-
-    return {
-        'id': sample['id'],
-        'ntokens': sample['ntokens'],
-        'target': make_variable(sample['target']),
-        'net_input': {
-            key: make_variable(sample[key])
-            for key in ['src_tokens', 'input_tokens']
-        },
-    }
+    def _make_variable(maybe_tensor):
+        if torch.is_tensor(maybe_tensor):
+            if cuda_device is not None and torch.cuda.is_available():
+                maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
+            return Variable(maybe_tensor, volatile=volatile)
+        elif isinstance(maybe_tensor, dict):
+            return {
+                key: _make_variable(value)
+                for key, value in maybe_tensor.items()
+            }
+        elif isinstance(maybe_tensor, list):
+            return [_make_variable(x) for x in maybe_tensor]
+        else:
+            return maybe_tensor
+
+    return _make_variable(sample)


 def load_align_dict(replace_unk):
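make_variable now walks arbitrarily nested dicts and lists instead of assuming the fixed key layout of the old prepare_sample, so a whole collated sample can be wrapped in one call and non-tensor values pass through unchanged. A minimal sketch of the behaviour under the pre-0.4 Variable API used above; the batch contents are made up for illustration:

    import torch
    from torch.autograd import Variable
    from fairseq import utils

    batch = {
        'id': torch.LongTensor([0, 1]),
        'ntokens': 6,  # plain ints are returned unchanged
        'net_input': {
            'src_tokens': torch.LongTensor([[4, 5, 2], [6, 2, 1]]),
            'input_tokens': torch.LongTensor([[2, 8, 9], [2, 7, 1]]),
        },
        'target': torch.LongTensor([[8, 9, 2], [7, 2, 1]]),
    }

    batch = utils.make_variable(batch, volatile=True)  # recurses into the nested dict
    assert isinstance(batch['net_input']['src_tokens'], Variable)
    assert batch['ntokens'] == 6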
@@ -247,6 +249,14 @@ def rstrip_pad(tensor, pad):
     return tensor


+def strip_pad(tensor, pad):
+    if tensor[0] == pad:
+        tensor = lstrip_pad(tensor, pad)
+    if tensor[-1] == pad:
+        tensor = rstrip_pad(tensor, pad)
+    return tensor
+
+
 def maybe_no_grad(condition):
     if hasattr(torch, 'no_grad') and condition:
         return torch.no_grad()
...
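The new strip_pad dispatches to the existing lstrip_pad/rstrip_pad helpers, so callers such as SequenceGenerator no longer need to know whether a sequence was left- or right-padded. A small illustrative check; the pad index 1 is only an example, not fixed by this commit:

    pad_idx = 1  # example padding index
    t = torch.LongTensor([1, 1, 4, 5, 2, 1])
    assert utils.strip_pad(t, pad_idx).tolist() == [4, 5, 2]  # strips both ends
    assert utils.strip_pad(torch.LongTensor([4, 5, 2]), pad_idx).tolist() == [4, 5, 2]  # no-op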
@@ -159,7 +159,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         del loss_dict['loss']  # don't include in extra_meters or extra_postfix

         ntokens = sum(s['ntokens'] for s in sample)
-        nsentences = sum(s['src_tokens'].size(0) for s in sample)
+        nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
         loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
         bsz_meter.update(nsentences)
         wpb_meter.update(ntokens)
...