Unverified Commit dd54a4b0 authored by Antoni Baum, committed by GitHub

Fix detokenization leaving special tokens (#1044)


Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
parent eda1a7ca
@@ -23,7 +23,8 @@ TOKENIZERS = [
 ]


-def _run_incremental_decode(tokenizer, all_input_ids):
+def _run_incremental_decode(tokenizer, all_input_ids,
+                            skip_special_tokens: bool):
     decoded_text = ""
     offset = 0
     token_offset = 0
@@ -35,7 +36,7 @@ def _run_incremental_decode(tokenizer, all_input_ids):
             prev_tokens,
             offset,
             token_offset,
-            skip_special_tokens=False)
+            skip_special_tokens=skip_special_tokens)
         decoded_text += text
         if prev_tokens is None:
             prev_tokens = new_tokens
@@ -46,10 +47,16 @@ def _run_incremental_decode(tokenizer, all_input_ids):

 @pytest.mark.parametrize("truth", TRUTH)
 @pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-def test_decode_streaming(tokenizer_id, truth):
+@pytest.mark.parametrize("skip_special_tokens", (True, False))
+def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
     all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+    if skip_special_tokens:
+        all_input_ids = ([tokenizer.bos_token_id]
+                         if tokenizer.bos_token_id is not None else
+                         []) + all_input_ids + [tokenizer.eos_token_id]

-    decoded_text = _run_incremental_decode(tokenizer, all_input_ids)
+    decoded_text = _run_incremental_decode(
+        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)

     assert decoded_text == truth
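The test only wraps the prompt in BOS/EOS when skip_special_tokens is True: with skipping enabled, the incremental detokenizer is expected to drop those tokens so the streamed text still equals truth, whereas with skipping disabled the same ids would render the special tokens and the comparison would no longer hold. The expectation mirrors the tokenizer's ordinary non-incremental decode. A minimal sketch of that baseline behavior, using an arbitrary tokenizer (facebook/opt-125m) and example string that are not necessarily the ones in the test suite:

from transformers import AutoTokenizer

# Illustrative tokenizer and text; not necessarily in TOKENIZERS / TRUTH.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
truth = "Hello here, this is a simple test"

input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
wrapped = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]

# With skipping enabled, BOS/EOS are dropped and the original text comes back.
print(tokenizer.decode(wrapped, skip_special_tokens=True))
# Without skipping, the special tokens leak into the decoded string.
print(tokenizer.decode(wrapped, skip_special_tokens=False))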
@@ -119,9 +119,9 @@ def detokenize_incrementally(
         prefix_offset = max(len(output_tokens) - 6, 0)
         read_offset = max(len(output_tokens) - 1, 0)
     else:
-        new_token = tokenizer.convert_ids_to_tokens(
-            new_token_id, skip_special_tokens=skip_special_tokens)
-        new_tokens = [new_token]
+        # Put new_token_id in a list so skip_special_tokens is respected
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            [new_token_id], skip_special_tokens=skip_special_tokens)
         output_tokens = prev_tokens + new_tokens

     # The prefix text is necessary only to defeat cleanup algorithms in
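Wrapping new_token_id in a one-element list is the heart of the fix: Hugging Face's convert_ids_to_tokens only applies the skip_special_tokens filter when it receives a list, while a bare integer id is converted unconditionally, which is how special tokens were leaking into streamed output. A minimal sketch of the difference, assuming an arbitrary tokenizer (facebook/opt-125m) rather than any specific one from the vLLM tests:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
eos = tokenizer.eos_token_id

# A bare int bypasses the filter: the special token string is returned anyway.
print(tokenizer.convert_ids_to_tokens(eos, skip_special_tokens=True))

# A one-element list goes through the filtering path, so the special token
# is dropped and an empty list comes back.
print(tokenizer.convert_ids_to_tokens([eos], skip_special_tokens=True))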