"git@developer.sourcefind.cn:change/sglang.git" did not exist on "ab9b893e61c47c5c25ee59934e3881b99401e35b"
Commit c3df2136 authored by LysandreJik
Browse files

Added binary masking tests

parent e391d473
......@@ -186,3 +186,15 @@ class CommonTestCases:
for weights_list_2 in weights_lists_2:
self.assertListEqual(weights_list, weights_list_2)
def test_mask_output(self):
if sys.version_info <= (3, 0):
return
tokenizer = self.get_tokenizer()
if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
seq_0 = "Test this method."
seq_1 = "With these inputs."
sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True)
assert len(sequences) == len(mask)
......@@ -690,6 +690,8 @@ class PreTrainedTokenizer(object):
if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
else:
if output_mask:
logger.warning("Can't output mask if no special tokens are involved. Please call the method with add_special_tokens set to True.")
return first_sentence_tokens, second_sentence_tokens
def add_special_tokens_single_sentence(self, token_ids):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment