Commit abc13e28 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Make hub_utils.generator inherit from nn.Module

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/913

Differential Revision: D16536562

Pulled By: myleott

fbshipit-source-id: ce28642da6868ec884e3e416388a652977a062df
parent 5218a7c9
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from fairseq import utils from fairseq import utils
from fairseq.data import encoders from fairseq.data import encoders
from fairseq.models import BaseFairseqModel
def from_pretrained( def from_pretrained(
...@@ -57,7 +58,7 @@ def from_pretrained( ...@@ -57,7 +58,7 @@ def from_pretrained(
} }
class Generator(object): class Generator(BaseFairseqModel):
"""PyTorch Hub API for generating sequences from a pre-trained translation """PyTorch Hub API for generating sequences from a pre-trained translation
or language model.""" or language model."""
...@@ -69,6 +70,11 @@ class Generator(object): ...@@ -69,6 +70,11 @@ class Generator(object):
self.tgt_dict = task.target_dictionary self.tgt_dict = task.target_dictionary
self.use_cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False) self.use_cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False)
if self.use_cuda:
if getattr(args, 'fp16', False):
self.half()
self.cuda()
# optimize model for generation # optimize model for generation
for model in self.models: for model in self.models:
model.make_generation_fast_( model.make_generation_fast_(
...@@ -78,10 +84,6 @@ class Generator(object): ...@@ -78,10 +84,6 @@ class Generator(object):
), ),
need_attn=getattr(args, 'print_alignment', False), need_attn=getattr(args, 'print_alignment', False),
) )
if self.use_cuda:
if getattr(args, 'fp16', False):
model.half()
model.cuda()
self.generator = self.task.build_generator(args) self.generator = self.task.build_generator(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment