Commit 860010e9 authored by Dmytro Okhonko, committed by Facebook Github Bot

Handle 3+ dimensional input in sequence_generator + nits

Summary: sequence_generator assumes that the model input is a 2D tensor of longs, but it can be something like a 3D tensor of floats, and we should be able to handle this as long as the first dimension is the batch size, followed by the source lengths.

Reviewed By: myleott

Differential Revision: D14420044

fbshipit-source-id: bf8b1e42ad1873f7b803c1a377b0af21648db015
parent d17fa851
@@ -223,6 +223,8 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
+    group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
+                       help='batch size will be a multiplier of this value')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            choices=['train', 'valid', 'test'],
@@ -359,6 +361,8 @@ def add_common_eval_args(group):
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override model args at generation '
                             'that were used during model training')
+    group.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
+                       help='path to save eval results (optional)')
     # fmt: on
...
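How the two new options parse, as a minimal self-contained sketch: a bare argparse parser stands in for fairseq's option groups, with the flag names, types, and defaults copied from the hunks above; everything else here is illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
                    help='batch size will be a multiplier of this value')
parser.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
                    help='path to save eval results (optional)')

# e.g. the test below overrides the default multiple of 8 with 1
args = parser.parse_args(['--required-batch-size-multiple', '1'])
assert args.required_batch_size_multiple == 1
assert args.results_path is None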
@@ -127,7 +127,10 @@ class SequenceGenerator(object):
         src_tokens = encoder_input['src_tokens']
         src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
-        bsz, src_len = src_tokens.size()
+        input_size = src_tokens.size()
+        # batch dimension goes first followed by source lengths
+        bsz = input_size[0]
+        src_len = input_size[1]
         beam_size = self.beam_size
         if self.match_source_len:
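Why the unpacking had to change: for a 3D input, bsz, src_len = src_tokens.size() raises ValueError ("too many values to unpack"), while indexing the first two dimensions works for any rank. A small sketch with hypothetical text-shaped and speech-shaped tensors:

import torch

text_input = torch.ones(4, 12, dtype=torch.long)  # (bsz, src_len): 2D longs
speech_input = torch.randn(4, 12, 80)             # (bsz, src_len, feat_dim): 3D floats

for src_tokens in (text_input, speech_input):
    # bsz, src_len = src_tokens.size() would fail on the 3D case
    input_size = src_tokens.size()
    bsz, src_len = input_size[0], input_size[1]
    print(bsz, src_len)  # 4 12 for both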
@@ -148,7 +151,7 @@ class SequenceGenerator(object):
         # initialize buffers
         scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
         scores_buf = scores.clone()
-        tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
+        tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = bos_token or self.eos
         attn, attn_buf = None, None
...
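The companion fix: Tensor.new(...) allocates with the source tensor's dtype, so once src_tokens may be float, the token buffer needs an explicit .long() to stay integer (.data.new is the older Variable-era spelling of the same allocation). A sketch under that assumption, with a hypothetical float input:

import torch

src_tokens = torch.randn(2, 5, 80)                  # 3D float input, as in speech
tokens = src_tokens.new(2 * 4, 10).long().fill_(1)  # without .long() this would be float
assert tokens.dtype == torch.long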
@@ -202,7 +202,7 @@ class Trainer(object):
                     sample_sizes.append(sample_size)
                 except RuntimeError as e:
                     if 'out of memory' in str(e):
-                        print('| WARNING: ran out of memory, skipping batch')
+                        print(('| WARNING: ran out of memory with exception: {};\n Skipping batch').format(str(e)))
                         ooms += 1
                         self.zero_grad()
                     else:
...
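For context, the pattern this hunk lives in, sketched as a standalone function: trainer and its methods are hypothetical stand-ins, and only the warning string comes from the diff. CUDA out-of-memory failures surface as RuntimeError, and the new message records which exception was swallowed instead of discarding it:

def train_step_skipping_oom(trainer, batch):
    # Schematic: run one step, tolerate CUDA OOM, re-raise anything else.
    try:
        trainer.forward_backward(batch)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print(('| WARNING: ran out of memory with exception: {};\n Skipping batch').format(str(e)))
            trainer.zero_grad()
        else:
            raise e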
@@ -74,7 +74,7 @@ def main(args):
             *[model.max_positions() for model in models]
         ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
-        required_batch_size_multiple=8,
+        required_batch_size_multiple=args.required_batch_size_multiple,
         num_shards=args.num_shards,
         shard_id=args.shard_id,
         num_workers=args.num_workers,
...
@@ -234,6 +234,7 @@ class TestCommonOptions(unittest.TestCase):
         if os.path.exists(last_checkpoint):
             os.remove(last_checkpoint)
         train_translation_model(data_dir, 'lstm', [
+            '--required-batch-size-multiple', '1',
             '--encoder-layers', '1',
             '--encoder-hidden-size', '32',
             '--decoder-layers', '1',
...
@@ -82,7 +82,7 @@ def main(args, init_distributed=False):
         max_sentences=args.max_sentences,
         max_positions=max_positions,
         ignore_invalid_inputs=True,
-        required_batch_size_multiple=8,
+        required_batch_size_multiple=args.required_batch_size_multiple,
         seed=args.seed,
         num_shards=args.distributed_world_size,
         shard_id=args.distributed_rank,
@@ -220,7 +220,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
             trainer.get_model().max_positions(),
         ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
-        required_batch_size_multiple=8,
+        required_batch_size_multiple=args.required_batch_size_multiple,
         seed=args.seed,
         num_shards=args.distributed_world_size,
         shard_id=args.distributed_rank,
...
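What required_batch_size_multiple controls, roughly: batch sizes are kept at a multiple of N, which was previously hard-coded to 8 (a value that typically suits fp16 tensor-core kernels, though the diff does not say why), and the test above passes 1 to disable the rounding. A sketch of the rounding rule only, not fairseq's actual batching code:

def round_down_to_multiple(n, multiple):
    # Largest value <= n that is a multiple of `multiple` (0 if n < multiple).
    return (n // multiple) * multiple

assert round_down_to_multiple(13, 8) == 8
assert round_down_to_multiple(7, 8) == 0   # degenerate case a real batcher must handle
assert round_down_to_multiple(7, 1) == 7   # multiple=1 disables rounding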