"src/vscode:/vscode.git/clone" did not exist on "5d848ec07c2011d600ce5e5c1aa02a03152aea9b"
Commit 2f6d8b35 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add missing files for RoBERTa hub interface

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/923

Differential Revision: D16541289

Pulled By: myleott

fbshipit-source-id: b3563a9d61507d4864ac6ecf0648672eaa40b5f3
parent 36df0dad
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .hub_interface import * # noqa
from .model import * # noqa
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import encoders
class RobertaHubInterface(nn.Module):
"""A simple PyTorch Hub interface to RoBERTa.
Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta
"""
def __init__(self, args, task, model):
super().__init__()
self.args = args
self.task = task
self.model = model
self.bpe = encoders.build_bpe(args)
# this is useful for determining the device
self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
for s in addl_sentences:
bpe_sentence += ' </s> ' + self.bpe.encode(s)
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True)
return tokens.long()
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
features, extra = self.model(
tokens.to(device=self.device),
features_only=True,
return_all_hiddens=return_all_hiddens,
)
if return_all_hiddens:
# convert from T x B x C -> B x T x C
inner_states = extra['inner_states']
return [inner_state.transpose(0, 1) for inner_state in inner_states]
else:
return features # just the last layer's features
def register_classification_head(
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
):
self.model.register_classification_head(
name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
)
def predict(self, head: str, tokens: torch.LongTensor):
features = self.extract_features(tokens)
logits = self.model.classification_heads[head](features)
return F.log_softmax(logits, dim=-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment