Unverified commit 2d8ee981, authored by Sanchit Gandhi and committed by GitHub
Browse files

[Wav2Vec2] Fix tokenizer set lang (#26349)

* fix wav2vec2 doctest

* suggestion

* fix

* final fix

* revert since we need AddedTokens
parent f9ab07f9
...@@ -206,14 +206,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): ...@@ -206,14 +206,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
# make sure that tokens made of several # make sure that tokens made of several
# characters are not split at tokenization # characters are not split at tokenization
# TODO @ArthurZ add them or just update the trie?
unique_no_split_tokens = []
for token in self.encoder.keys(): for token in self.encoder.keys():
if len(token) > 1: if len(token) > 1:
unique_no_split_tokens.append(AddedToken(token, rstrip=True, lstrip=True, normalized=False)) self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
self.add_tokens(unique_no_split_tokens)
def set_target_lang(self, target_lang: str): def set_target_lang(self, target_lang: str):
""" """
...@@ -232,7 +227,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): ...@@ -232,7 +227,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
# make sure that tokens made of several # make sure that tokens made of several
# characters are not split at tokenization # characters are not split at tokenization
self.add_tokens([token for token in self.encoder.keys() if len(token) > 1]) for token in self.encoder.keys():
if len(token) > 1:
self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
@property @property
def word_delimiter_token(self) -> str: def word_delimiter_token(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment