"tests/data/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e87a04e3c7c16aeb28ccd0cba693aa7c5ebaab48"
Commit 930c9580 authored by Myle Ott's avatar Myle Ott
Browse files

Support FP16 during inference

parent 9a88b71d
......@@ -35,6 +35,8 @@ def main(args):
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
model.make_generation_fast_()
if args.fp16:
model.half()
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
......
......@@ -117,6 +117,7 @@ def get_parser(desc, default_task='translation'):
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--fp16', action='store_true', help='use FP16')
# Task definitions can be found under fairseq/tasks/
parser.add_argument(
......@@ -187,8 +188,6 @@ def add_optimization_args(parser):
' (default is to normalize by number of tokens)')
group.add_argument('--update-freq', default='1', metavar='N',
help='update parameters every N_i batches, when in epoch i')
group.add_argument('--fp16', action='store_true',
help='use FP16 during training')
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
......
......@@ -43,6 +43,8 @@ def main(args):
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
if args.fp16:
model.half()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
......
......@@ -82,9 +82,9 @@ def main(args):
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
)
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
if args.fp16:
model.half()
# Initialize generator
translator = SequenceGenerator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment