"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "8110a872d5b8fb88cd28c694486781f2dcb9a130"
Unverified commit 12f043ea, authored by Tanay Mehta, committed by GitHub

Fix `MarianTokenizer` to remove metaspace character in `decode` (#26091)

* add: check to remove metaspace from marian tokenizer

* fix: metaspace character being removed from everywhere

* fix: remove redundant check at top

* add: test for marian tokenizer decode fix

* fix: simplified the test
parent 03e309d5
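
Background on the change: SentencePiece marks word boundaries with the metaspace character U+2581 ("▁") rather than a literal space. When `convert_tokens_to_string` stitches decoded pieces back together around special tokens, a leftover marker could survive into the decoded string. The essence of the fix, as a standalone sketch (illustrative names only; just `SPIECE_UNDERLINE` mirrors the constant added below):

```python
# Standalone sketch of the metaspace problem this commit fixes.
# Only SPIECE_UNDERLINE mirrors the real constant; the helper is hypothetical.
SPIECE_UNDERLINE = "▁"  # U+2581, SentencePiece's word-boundary marker

def finalize_decoded(raw: str) -> str:
    # The fix in one line: map leftover metaspace markers back to plain
    # spaces, then strip the outer whitespace.
    return raw.replace(SPIECE_UNDERLINE, " ").strip()

print(finalize_decoded("▁Hello▁World"))  # -> "Hello World"
```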
```diff
@@ -55,6 +55,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512}
 PRETRAINED_INIT_CONFIGURATION = {}
+SPIECE_UNDERLINE = "▁"
+
 # Example URL https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json
```
```diff
@@ -278,6 +280,7 @@ class MarianTokenizer(PreTrainedTokenizer):
             else:
                 current_sub_tokens.append(token)
         out_string += sp_model.decode_pieces(current_sub_tokens)
+        out_string = out_string.replace(SPIECE_UNDERLINE, " ")
         return out_string.strip()

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
```
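
To see the effect end to end, a quick usage sketch (assumes a `transformers` build containing this commit, `sentencepiece` installed, and the Helsinki-NLP checkpoint from the example URL above being reachable):

```python
from transformers import MarianTokenizer

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
ids = tokenizer("Hello World")["input_ids"]

# Before this fix, decode() could leave "▁" markers in the output; with
# the replace() above, word boundaries come back as ordinary spaces.
print(tokenizer.decode(ids, skip_special_tokens=True))  # expected: Hello World
```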
```diff
@@ -149,3 +149,10 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
         self.assertEqual(decoded, target_text)
+
+    def test_tokenizer_decode(self):
+        tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
+        source_text = "Hello World"
+        ids = tokenizer(source_text)["input_ids"]
+        output_text = tokenizer.decode(ids, skip_special_tokens=True)
+        self.assertEqual(source_text, output_text)
```
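
Together with the tokenizer change above, the new test pins down the intended behavior: encoding a plain string and decoding it with `skip_special_tokens=True` round-trips to the original text, with no stray metaspace characters left behind.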