Commit d82517e9 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add roberta.decode to hub interface to decode BPE (#931)

Summary:
Fixes https://github.com/pytorch/fairseq/issues/930.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/931

Differential Revision: D16562511

Pulled By: myleott

fbshipit-source-id: c4c07e2f067326b79daa547dcb3db84aeddbd555
parent 3b2cecda
......@@ -60,6 +60,8 @@ $ tar -xzvf roberta.large.tar.gz
>>> tokens = roberta.encode('Hello world!')
>>> tokens
tensor([ 0, 31414, 232, 328, 2])
>>> roberta.decode(tokens)
'Hello world!'
```
##### Extract features from RoBERTa:
......
......@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -38,6 +39,19 @@ class RobertaHubInterface(nn.Module):
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
return tokens.long()
def decode(self, tokens: torch.LongTensor):
    """Decode a tensor of token ids back into a string.

    Inverse of :func:`encode`: strips a leading ``<s>`` if present,
    splits on consecutive ``</s></s>`` separators into sentences, and
    reverses the BPE encoding of each sentence.

    Args:
        tokens: 1-D tensor of token ids, e.g. as produced by
            :func:`encode`. May live on any device.

    Returns:
        A single decoded string, or a list of strings when *tokens*
        contains multiple ``</s>``-separated sentences.
    """
    assert tokens.dim() == 1
    # .numpy() raises a TypeError on CUDA tensors, so move to CPU first.
    tokens = tokens.cpu().numpy()
    if len(tokens) > 0 and tokens[0] == self.task.source_dictionary.bos():
        tokens = tokens[1:]  # remove <s>
    eos_mask = (tokens == self.task.source_dictionary.eos())
    # Two adjacent </s> tokens mark a boundary between sentences.
    doc_mask = eos_mask[1:] & eos_mask[:-1]
    sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
    sentences = [
        self.bpe.decode(self.task.source_dictionary.string(s))
        for s in sentences
    ]
    if len(sentences) == 1:
        return sentences[0]
    return sentences
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment