"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "c9bcffd2a53423e6a183e312a58675fb48435d2a"
Commit 5bdee18e authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Iterate on torch.hub interface

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/793

Differential Revision: D15758755

Pulled By: myleott

fbshipit-source-id: b93e4ac11bde36a0b59b4d6d1c84d31c3124d767
parent eea4d20b
@@ -5,48 +5,11 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.models.transformer import TransformerModel
from fairseq.models import MODEL_REGISTRY
from fairseq.models.fconv import FConvModel
from fairseq.models.fconv_self_att import FConvModelSelfAtt
from generator import Generator
from fairseq import options
# Packages torch.hub must install before loading these entry points.
dependencies = ['torch', 'sacremoses', 'subword_nmt']
def transformer(*args, **kwargs):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Extra positional/keyword arguments are forwarded to
    ``TransformerModel.from_pretrained``.
    """
    parser = options.get_interactive_generation_parser()
    return TransformerModel.from_pretrained(parser, *args, **kwargs)
def fconv(*args, **kwargs):
    """
    A fully convolutional model, i.e. a convolutional encoder and a
    convolutional decoder, as described in `"Convolutional Sequence to Sequence
    Learning" (Gehring et al., 2017) <https://arxiv.org/abs/1705.03122>`_.

    Extra positional/keyword arguments are forwarded to
    ``FConvModel.from_pretrained``.
    """
    parser = options.get_interactive_generation_parser()
    return FConvModel.from_pretrained(parser, *args, **kwargs)
def fconv_self_att(*args, **kwargs):
    """Load a pretrained self-attentive fully convolutional model.

    Delegates to ``FConvModelSelfAtt.from_pretrained`` with an
    interactive-generation argument parser.
    """
    parser = options.get_interactive_generation_parser()
    return FConvModelSelfAtt.from_pretrained(parser, *args, **kwargs)
def generator(*args, **kwargs):
    """Load a pretrained ``Generator``.

    Uses the generation parser in interactive mode; extra arguments are
    forwarded to ``Generator.from_pretrained``.
    """
    parser = options.get_generation_parser(interactive=True)
    return Generator.from_pretrained(parser, *args, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment