Commit 62976337 authored by soqeue1's avatar soqeue1
Browse files

fix: remove assert

parent 2987beb0
......@@ -27,16 +27,16 @@ class HFLM(BaseLM):
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
assert isinstance(self.tokenizer, (
transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
transformers.T5Tokenizer, transformers.T5TokenizerFast,
)), "this tokenizer has not been checked for compatibility yet!"
# assert isinstance(self.tokenizer, (
# transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
# transformers.T5Tokenizer, transformers.T5TokenizerFast,
# )), "this tokenizer has not been checked for compatibility yet!"
self.vocab_size = self.tokenizer.vocab_size
if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
self.tokenizer.encode('hello\n\nhello')
# if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
# assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
# self.tokenizer.encode('hello\n\nhello')
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
......@@ -75,7 +75,7 @@ class HFLM(BaseLM):
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
......@@ -89,7 +89,7 @@ class HFLM(BaseLM):
"""
with torch.no_grad():
return self.gpt2(inps)[0][:, :, :50257]
def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
context,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment