Commit 930c9580 authored by Myle Ott's avatar Myle Ott
Browse files

Support FP16 during inference

parent 9a88b71d
@@ -35,6 +35,8 @@ def main(args):
     # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
     for model in models:
         model.make_generation_fast_()
+        if args.fp16:
+            model.half()

     itr = data.EpochBatchIterator(
         dataset=task.dataset(args.gen_subset),
@@ -117,6 +117,7 @@ def get_parser(desc, default_task='translation'):
                         choices=['json', 'none', 'simple', 'tqdm'])
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
+    parser.add_argument('--fp16', action='store_true', help='use FP16')

     # Task definitions can be found under fairseq/tasks/
     parser.add_argument(
@@ -187,8 +188,6 @@ def add_optimization_args(parser):
                        ' (default is to normalize by number of tokens)')
     group.add_argument('--update-freq', default='1', metavar='N',
                        help='update parameters every N_i batches, when in epoch i')
-    group.add_argument('--fp16', action='store_true',
-                       help='use FP16 during training')

     # Optimizer definitions can be found under fairseq/optim/
     group.add_argument('--optimizer', default='nag', metavar='OPT',
@@ -43,6 +43,8 @@ def main(args):
     # Optimize ensemble for generation
     for model in models:
         model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        if args.fp16:
+            model.half()

     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)
@@ -82,9 +82,9 @@ def main(args):
     # Optimize ensemble for generation
     for model in models:
-        model.make_generation_fast_(
-            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
-        )
+        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        if args.fp16:
+            model.half()

     # Initialize generator
     translator = SequenceGenerator(
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment