Commit d48895bd authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

fixed word level extract features for roberta-xlmr

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/933

Differential Revision: D18783780

fbshipit-source-id: fa0a27fab886a5fa5be8d5f49151d1d9dd9775f1
parent 1c565940
...@@ -22,6 +22,7 @@ def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List ...@@ -22,6 +22,7 @@ def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List
List[str]: mapping from *other_tokens* to corresponding *bpe_tokens*. List[str]: mapping from *other_tokens* to corresponding *bpe_tokens*.
""" """
assert bpe_tokens.dim() == 1 assert bpe_tokens.dim() == 1
assert bpe_tokens[0] == 0
def clean(text): def clean(text):
return text.strip() return text.strip()
...@@ -32,7 +33,6 @@ def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List ...@@ -32,7 +33,6 @@ def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List
other_tokens = [clean(str(o)) for o in other_tokens] other_tokens = [clean(str(o)) for o in other_tokens]
# strip leading <s> # strip leading <s>
assert bpe_tokens[0] == '<s>'
bpe_tokens = bpe_tokens[1:] bpe_tokens = bpe_tokens[1:]
assert ''.join(bpe_tokens) == ''.join(other_tokens) assert ''.join(bpe_tokens) == ''.join(other_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment