Commit c132b9b9 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix tokenization (fixes #926) (#929)

Summary:
Fixes https://github.com/pytorch/fairseq/issues/926
Pull Request resolved: https://github.com/pytorch/fairseq/pull/929

Differential Revision: D16560281

Pulled By: myleott

fbshipit-source-id: 751051bcdbf25207315bb05f5bee0235d21be627
parent 138dc8e4
...@@ -36,8 +36,8 @@ class RobertaHubInterface(nn.Module): ...@@ -36,8 +36,8 @@ class RobertaHubInterface(nn.Module):
def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor: def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>' bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
for s in addl_sentences: for s in addl_sentences:
bpe_sentence += ' </s> ' + self.bpe.encode(s) bpe_sentence += ' </s> ' + self.bpe.encode(s) + ' </s>'
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True) tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
return tokens.long() return tokens.long()
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor: def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment