Commit 88ed8930 authored by comfyanonymous
Browse files

Allow SPieceTokenizer to load model from a byte string.

parent 334ba48c
import os import os
class SPieceTokenizer: class SPieceTokenizer:
add_eos = True
@staticmethod @staticmethod
def from_pretrained(path): def from_pretrained(path):
return SPieceTokenizer(path) return SPieceTokenizer(path)
def __init__(self, tokenizer_path): def __init__(self, tokenizer_path):
import sentencepiece import sentencepiece
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path) if isinstance(tokenizer_path, bytes):
self.end = self.tokenizer.eos_id() self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
else:
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
def get_vocab(self): def get_vocab(self):
out = {} out = {}
...@@ -18,5 +22,4 @@ class SPieceTokenizer: ...@@ -18,5 +22,4 @@ class SPieceTokenizer:
def __call__(self, string): def __call__(self, string):
out = self.tokenizer.encode(string) out = self.tokenizer.encode(string)
out += [self.end]
return {"input_ids": out} return {"input_ids": out}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment