Commit 088fa7b7 authored by Lysandre's avatar Lysandre
Browse files

Correct segment ID for XLNet single sequence

parent 073219b4
...@@ -240,7 +240,7 @@ class XLNetTokenizer(PreTrainedTokenizer): ...@@ -240,7 +240,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
cls_segment_id = [2] cls_segment_id = [2]
if token_ids_1 is None: if token_ids_1 is None:
return len(token_ids_0 + sep + cls) * [0] return len(token_ids_0 + sep) * [0] + cls_segment_id
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment