Commit 05f96184 authored by thomwolf

logging

parent 3a848111
@@ -125,16 +125,19 @@ class OpenAIGPTTokenizer(object):
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
-        if not special_tokens:
-            self.special_tokens = {}
-        else:
-            self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.set_special_tokens(special_tokens)

     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)

+    def set_special_tokens(self, special_tokens):
+        if not special_tokens:
+            self.special_tokens = {}
+            self.special_tokens_decoder = {}
+            return
+        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + '</w>',)
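For reference, a minimal usage sketch of the new set_special_tokens flow. This is hypothetical and not part of the commit: the constructor arguments and the token names below are illustrative.

# Hypothetical sketch: special tokens are indexed after the BPE vocabulary,
# so their ids start at len(tokenizer.encoder), and __len__ counts both.
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
tokenizer.set_special_tokens(["<bos>", "<del>", "<eos>"])
assert len(tokenizer) == len(tokenizer.encoder) + 3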
@@ -189,6 +192,11 @@ class OpenAIGPTTokenizer(object):
     def convert_tokens_to_ids(self, tokens):
         """Converts a sequence of tokens into ids using the vocab."""
         ids = []
+        if isinstance(tokens, str):
+            if tokens in self.special_tokens:
+                return self.special_tokens[tokens]
+            else:
+                return self.encoder.get(tokens, 0)
         for token in tokens:
             if token in self.special_tokens:
                 ids.append(self.special_tokens[token])
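Note that the new single-string fast path means convert_tokens_to_ids can return either a bare id or a list of ids. A hypothetical example of both shapes, continuing the illustrative names above:

# Hypothetical: a lone special token returns an int; a list still returns a list.
bos_id = tokenizer.convert_tokens_to_ids("<bos>")              # id from special_tokens
ids = tokenizer.convert_tokens_to_ids(["<bos>", "hello</w>"])  # list of ids
# A single unknown string falls back to id 0 via self.encoder.get(tokens, 0).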
@@ -206,6 +214,9 @@ class OpenAIGPTTokenizer(object):
         """Converts a sequence of ids in BPE tokens using the vocab."""
         tokens = []
         for i in ids:
-            tokens.append(self.decoder[i])
+            if i in self.special_tokens_decoder:
+                tokens.append(self.special_tokens_decoder[i])
+            else:
+                tokens.append(self.decoder[i])
         return tokens
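And a hypothetical round trip through the new special_tokens_decoder branch, again with illustrative names:

# Hypothetical: special-token ids decode via special_tokens_decoder,
# ordinary BPE ids via self.decoder.
assert tokenizer.convert_ids_to_tokens([bos_id]) == ["<bos>"]
assert tokenizer.convert_ids_to_tokens(ids)[0] == "<bos>"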