Commit 01a3966b authored by thomwolf's avatar thomwolf
Browse files

more options on special tokens

parent 05f96184
......@@ -131,6 +131,10 @@ class OpenAIGPTTokenizer(object):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
......@@ -210,18 +214,19 @@ class OpenAIGPTTokenizer(object):
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
    """Convert a sequence of ids into BPE tokens using the vocab.

    Ids present in ``self.special_tokens_decoder`` map to their special
    token; when ``skip_special_tokens`` is True those ids are omitted
    from the output entirely. Every other id is resolved through the
    regular ``self.decoder``.
    """
    special = self.special_tokens_decoder
    return [
        special[i] if i in special else self.decoder[i]
        for i in ids
        if not (skip_special_tokens and i in special)
    ]
def decode(self, ids, skip_special_tokens=False):
    """Convert a sequence of ids into a single string.

    The ids are decoded via ``convert_ids_to_tokens`` (optionally
    dropping special tokens), concatenated, and the BPE end-of-word
    marker ``</w>`` is rewritten as a plain space.
    """
    pieces = self.convert_ids_to_tokens(
        ids, skip_special_tokens=skip_special_tokens)
    return ''.join(pieces).replace('</w>', ' ')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment