"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "68ccc04ee6c762183ff2b34b8b85d139f77cbf14"
Commit 05f96184 authored by thomwolf's avatar thomwolf
Browse files

logging

parent 3a848111
...@@ -125,16 +125,19 @@ class OpenAIGPTTokenizer(object): ...@@ -125,16 +125,19 @@ class OpenAIGPTTokenizer(object):
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
if not special_tokens: self.set_special_tokens(special_tokens)
self.special_tokens = {}
else:
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
def __len__(self): def __len__(self):
return len(self.encoder) + len(self.special_tokens) return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens): def set_special_tokens(self, special_tokens):
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token): def bpe(self, token):
word = tuple(token[:-1]) + ( token[-1] + '</w>',) word = tuple(token[:-1]) + ( token[-1] + '</w>',)
...@@ -189,6 +192,11 @@ class OpenAIGPTTokenizer(object): ...@@ -189,6 +192,11 @@ class OpenAIGPTTokenizer(object):
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab.""" """Converts a sequence of tokens into ids using the vocab."""
ids = [] ids = []
if isinstance(tokens, str):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens: for token in tokens:
if token in self.special_tokens: if token in self.special_tokens:
ids.append(self.special_tokens[token]) ids.append(self.special_tokens[token])
...@@ -206,7 +214,10 @@ class OpenAIGPTTokenizer(object): ...@@ -206,7 +214,10 @@ class OpenAIGPTTokenizer(object):
"""Converts a sequence of ids in BPE tokens using the vocab.""" """Converts a sequence of ids in BPE tokens using the vocab."""
tokens = [] tokens = []
for i in ids: for i in ids:
tokens.append(self.decoder[i]) if i in self.special_tokens_decoder:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens return tokens
def decode(self, ids): def decode(self, ids):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment