Commit 86a63070 authored by erenup's avatar erenup
Browse files

Merge branch 'huggingface/master'

parents b5d73976 82f6abd9
......@@ -754,32 +754,59 @@ class XLMTokenizer(PreTrainedTokenizer):
out_string = ''.join(tokens).replace('</w>', ' ').strip()
return out_string
def add_special_tokens_single_sequence(self, token_ids):
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Adds special tokens to a sequence for sequence classification tasks.
An XLM sequence has the following format: [CLS] X [SEP]
"""
return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
A RoBERTa sequence has the following format:
single sequence: <s> X </s>
pair of sequences: <s> A </s></s> B </s>
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
An XLM sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
if token_ids_1 is None, only returns the first portion of the mask (0's).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory):
......
......@@ -181,36 +181,61 @@ class XLNetTokenizer(PreTrainedTokenizer):
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
return out_string
def add_special_tokens_single_sequence(self, token_ids):
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Adds special tokens to a sequence for sequence classification tasks.
An XLNet sequence has the following format: X [SEP][CLS]
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
A RoBERTa sequence has the following format:
single sequence: <s> X </s>
pair of sequences: <s> A </s></s> B </s>
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return token_ids + sep + cls
if token_ids_1 is None:
return token_ids_0 + sep + cls
return token_ids_0 + sep + token_ids_1 + sep + cls
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return token_ids_0 + sep + token_ids_1 + sep + cls
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1]
return ([0] * len(token_ids_0)) + [1, 1]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A BERT sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 2
| first sequence | second sequence | CLS segment ID
if token_ids_1 is None, only returns the first portion of the mask (0's).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
cls_segment_id = [2]
if token_ids_1 is None:
return len(token_ids_0 + sep + cls) * [0]
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
def save_vocabulary(self, save_directory):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment