Commit eb005cdb authored by Myle Ott

Streamline data formatting utils

parent 6f6cb4ab
@@ -229,13 +229,15 @@ class LanguagePairDataset(object):
        return {
            'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'ntokens': sum(len(s['target']) for s in samples),
            'net_input': {
                'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-                # we create a shifted version of targets for feeding the previous
-                # output token(s) into the next decoder step
+                # we create a shifted version of targets for feeding the
+                # previous output token(s) into the next decoder step
                'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                                      move_eos_to_beginning=True),
            },
            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
+            'ntokens': sum(len(s['target']) for s in samples),
        }

    @staticmethod
......
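The collater hunk above moves 'ntokens' next to 'target' and rewraps the comment; 'input_tokens' is still the target sequence shifted so each decoder step is fed the previous output token. A minimal sketch of that shift, outside the diff (the eos index 2 is an assumption, and merge() itself is not shown in this hunk):

import torch

eos = 2  # assumed end-of-sentence index
target = torch.LongTensor([4, 5, 6, eos])

# move_eos_to_beginning=True: rotate eos to the front, so position t of the
# decoder input holds the token the decoder should have produced at t-1
input_tokens = torch.cat([target[-1:], target[:-1]])  # -> [2, 4, 5, 6]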
@@ -381,4 +381,4 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                self._max_bsz_seen = sample['target'].size(0)
                torch.cuda.empty_cache()
-            self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
+            self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
@@ -65,7 +65,7 @@ class SequenceGenerator(object):
            maxlen_b = self.maxlen

        for sample in data_itr:
-            s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
+            s = utils.make_variable(sample, volatile=True, cuda_device=cuda_device)
            input = s['net_input']
            srclen = input['src_tokens'].size(1)
            if timer is not None:
@@ -74,10 +74,10 @@ class SequenceGenerator(object):
                maxlen=int(maxlen_a*srclen + maxlen_b))
            if timer is not None:
                timer.stop(s['ntokens'])
-            for i, id in enumerate(s['id']):
+            for i, id in enumerate(s['id'].data):
                src = input['src_tokens'].data[i, :]
                # remove padding from ref
-                ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
+                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                yield id, src, ref, hypos[i]

    def generate(self, src_tokens, beam_size=None, maxlen=None):
......
@@ -176,23 +176,25 @@ def _upgrade_args(args):
    return args


-def prepare_sample(sample, volatile=False, cuda_device=None):
+def make_variable(sample, volatile=False, cuda_device=None):
    """Wrap input tensors in Variable class."""

-    def make_variable(tensor):
-        if cuda_device is not None and torch.cuda.is_available():
-            tensor = tensor.cuda(async=True, device=cuda_device)
-        return Variable(tensor, volatile=volatile)
-
-    return {
-        'id': sample['id'],
-        'ntokens': sample['ntokens'],
-        'target': make_variable(sample['target']),
-        'net_input': {
-            key: make_variable(sample[key])
-            for key in ['src_tokens', 'input_tokens']
-        },
-    }
+    def _make_variable(maybe_tensor):
+        if torch.is_tensor(maybe_tensor):
+            if cuda_device is not None and torch.cuda.is_available():
+                maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
+            return Variable(maybe_tensor, volatile=volatile)
+        elif isinstance(maybe_tensor, dict):
+            return {
+                key: _make_variable(value)
+                for key, value in maybe_tensor.items()
+            }
+        elif isinstance(maybe_tensor, list):
+            return [_make_variable(x) for x in maybe_tensor]
+        else:
+            return maybe_tensor
+
+    return _make_variable(sample)


def load_align_dict(replace_unk):
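Besides the rename from prepare_sample to make_variable, the helper above no longer rebuilds a fixed sample layout: it recurses over whatever dicts and lists it is given and wraps only the tensors it finds. A rough usage sketch, under the assumption that the module is importable as fairseq.utils at this revision (the sample values are made up):

import torch
from fairseq import utils  # assumed import path for the module patched above

sample = {
    'id': torch.LongTensor([0, 1]),
    'ntokens': 7,  # non-tensor leaves pass through unchanged
    'net_input': {
        'src_tokens': torch.LongTensor([[4, 5, 2], [6, 2, 1]]),
        'input_tokens': torch.LongTensor([[2, 4, 5], [2, 6, 1]]),
    },
    'target': torch.LongTensor([[4, 5, 2], [6, 2, 1]]),
}

# every tensor anywhere in the nested structure is wrapped in Variable (and
# moved to cuda_device when one is given); keys no longer need to be hard-coded
batch = utils.make_variable(sample)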
@@ -247,6 +249,14 @@ def rstrip_pad(tensor, pad):
    return tensor


+def strip_pad(tensor, pad):
+    if tensor[0] == pad:
+        tensor = lstrip_pad(tensor, pad)
+    if tensor[-1] == pad:
+        tensor = rstrip_pad(tensor, pad)
+    return tensor
+
+
def maybe_no_grad(condition):
    if hasattr(torch, 'no_grad') and condition:
        return torch.no_grad()
......
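The new strip_pad simply chains the existing lstrip_pad and rstrip_pad, so callers like SequenceGenerator above can drop padding whether the batch was left- or right-padded. A small sketch of the intended behaviour (the pad index and the expected result are assumptions based on the helpers' names):

import torch
from fairseq import utils  # assumed import path, as above

pad = 1  # assumed padding index
ref = torch.LongTensor([pad, pad, 4, 5, 6, 2, pad])

# pad symbols at either end are stripped; interior tokens are left untouched
stripped = utils.strip_pad(ref, pad)  # expected: LongTensor([4, 5, 6, 2])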
@@ -159,7 +159,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
-            nsentences = sum(s['src_tokens'].size(0) for s in sample)
+            nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
......