Commit 860010e9 authored by Dmytro Okhonko's avatar Dmytro Okhonko Committed by Facebook Github Bot

Handle 3+ dimensional input in sequence_generator + nits

Summary: sequence_generator assumes that the model input is a 2D tensor of longs, but it can also be something like a 3D tensor of floats. We should handle such inputs as long as the first dimension is the batch size, followed by the source length.
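
For illustration, a minimal sketch of the unpacking problem being fixed (the tensor shapes here are hypothetical):

    import torch

    src_tokens = torch.zeros(4, 7, dtype=torch.long)      # (batch, src_len) token ids
    bsz, src_len = src_tokens.size()                       # fine for 2D input

    src_feats = torch.zeros(4, 7, 80)                      # (batch, src_len, feat_dim) float features
    # bsz, src_len = src_feats.size()  # ValueError: too many values to unpack

    # Indexing the first two dimensions works for any rank >= 2:
    input_size = src_feats.size()
    bsz, src_len = input_size[0], input_size[1]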

Reviewed By: myleott

Differential Revision: D14420044

fbshipit-source-id: bf8b1e42ad1873f7b803c1a377b0af21648db015
parent d17fa851
@@ -223,6 +223,8 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
+    group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
+                       help='batch size will be a multiple of this value')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            choices=['train', 'valid', 'test'],
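
A hedged sketch (not fairseq's actual batching code) of what constraining batch sizes to a multiple typically looks like; the default of 8 presumably targets GPU-efficient sizes, and the unit test later in this diff passes 1, which effectively disables the rounding:

    def round_to_multiple(num_sentences, multiple):
        # Largest size <= num_sentences that is a multiple of `multiple`
        # (floored at `multiple` itself so a batch is never empty).
        return max(multiple, (num_sentences // multiple) * multiple)

    assert round_to_multiple(30, 8) == 24
    assert round_to_multiple(30, 1) == 30  # '--required-batch-size-multiple 1' disables rounding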
@@ -359,6 +361,8 @@ def add_common_eval_args(group):
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override model args at generation '
                             'that were used during model training')
+    group.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
+                       help='path to save eval results (optional)')
     # fmt: on
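
For reference, a standalone argparse sketch (not the fairseq parser itself) showing how the new flag parses; the path is a placeholder:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
                        help='path to save eval results (optional)')
    args = parser.parse_args(['--results-path', '/tmp/eval_results'])
    assert args.results_path == '/tmp/eval_results'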
@@ -127,7 +127,10 @@ class SequenceGenerator(object):
         src_tokens = encoder_input['src_tokens']
         src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
-        bsz, src_len = src_tokens.size()
+        input_size = src_tokens.size()
+        # batch dimension goes first followed by source lengths
+        bsz = input_size[0]
+        src_len = input_size[1]
         beam_size = self.beam_size
         if self.match_source_len:
@@ -148,7 +151,7 @@ class SequenceGenerator(object):
         # initialize buffers
         scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
         scores_buf = scores.clone()
-        tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
+        tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = bos_token or self.eos
         attn, attn_buf = None, None
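
The added .long() matters for the new float inputs: Tensor.new() creates a tensor with the same dtype as its source, so building the token buffer from a float feature tensor would otherwise yield float tokens. A small sketch:

    import torch

    src_feats = torch.zeros(2, 5, 80)                       # float features, as in the 3D case
    tokens = src_feats.data.new(2 * 4, 12).long().fill_(1)  # buffer stays integral
    assert tokens.dtype == torch.long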
@@ -202,7 +202,7 @@ class Trainer(object):
                     sample_sizes.append(sample_size)
                 except RuntimeError as e:
                     if 'out of memory' in str(e):
-                        print('| WARNING: ran out of memory, skipping batch')
+                        print(('| WARNING: ran out of memory with exception: {};\n Skipping batch').format(str(e)))
                         ooms += 1
                         self.zero_grad()
                     else:
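
For context, a self-contained sketch of the skip-on-OOM pattern shown above; train_step and zero_grad are assumed callables standing in for the trainer's internals, not fairseq APIs:

    def step_or_skip(train_step, zero_grad):
        try:
            train_step()
        except RuntimeError as e:
            if 'out of memory' in str(e):
                # Log the full exception, then recover by clearing gradients.
                print('| WARNING: ran out of memory with exception: {};\n Skipping batch'.format(str(e)))
                zero_grad()
            else:
                raise e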
@@ -74,7 +74,7 @@ def main(args):
             *[model.max_positions() for model in models]
         ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
-        required_batch_size_multiple=8,
+        required_batch_size_multiple=args.required_batch_size_multiple,
         num_shards=args.num_shards,
         shard_id=args.shard_id,
         num_workers=args.num_workers,
@@ -234,6 +234,7 @@ class TestCommonOptions(unittest.TestCase):
         if os.path.exists(last_checkpoint):
             os.remove(last_checkpoint)
         train_translation_model(data_dir, 'lstm', [
+            '--required-batch-size-multiple', '1',
             '--encoder-layers', '1',
             '--encoder-hidden-size', '32',
             '--decoder-layers', '1',
@@ -82,7 +82,7 @@ def main(args, init_distributed=False):
         max_sentences=args.max_sentences,
         max_positions=max_positions,
         ignore_invalid_inputs=True,
-        required_batch_size_multiple=8,
+        required_batch_size_multiple=args.required_batch_size_multiple,
         seed=args.seed,
         num_shards=args.distributed_world_size,
         shard_id=args.distributed_rank,
@@ -220,7 +220,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
             trainer.get_model().max_positions(),
         ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
-        required_batch_size_multiple=8,
+        required_batch_size_multiple=args.required_batch_size_multiple,
         seed=args.seed,
         num_shards=args.distributed_world_size,
         shard_id=args.distributed_rank,