"...kernels/git@developer.sourcefind.cn:change/sglang.git" did not exist on "3844feb9bb1cdd1ee59653b85e3b40e8a4d107d1"
Commit b3c6ee0a authored by thomwolf's avatar thomwolf
Browse files

tokenization updates

parent 20577d8a
...@@ -135,9 +135,10 @@ class BertTokenizer(object): ...@@ -135,9 +135,10 @@ class BertTokenizer(object):
return tokens return tokens
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path.""" """Save the tokenizer vocabulary to a directory or file."""
index = 0 index = 0
vocab_file = os.path.join(vocab_path, VOCAB_NAME) if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
with open(vocab_file, "w", encoding="utf-8") as writer: with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
......
...@@ -145,8 +145,10 @@ class TransfoXLTokenizer(object): ...@@ -145,8 +145,10 @@ class TransfoXLTokenizer(object):
raise ValueError('No <unkown> token in vocabulary') raise ValueError('No <unkown> token in vocabulary')
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0 index = 0
vocab_file = os.path.join(vocab_path, VOCAB_NAME) if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
torch.save(self.__dict__, vocab_file) torch.save(self.__dict__, vocab_file)
return vocab_file return vocab_file
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment