Commit e5066c69 authored by Leo Gao

Make gpt2 a little more tokenizer agnostic

parent 7f8d6676
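
The constructor grows `subfolder` and `tokenizer` arguments, so a checkpoint stored in a repo subfolder can be loaded and paired with an explicitly named tokenizer instead of the previously hard-coded GPT-2 one. A minimal usage sketch (import path and model names are assumed for illustration, not taken from this commit):

    from lm_eval.models.gpt2 import GPT2LM

    # default: the tokenizer is resolved from the model name
    lm = GPT2LM(pretrained='gpt2')

    # hypothetical: a GPT-Neo checkpoint paired with the plain GPT-2 tokenizer
    lm = GPT2LM(pretrained='EleutherAI/gpt-neo-1.3B', tokenizer='gpt2')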
@@ -10,10 +10,8 @@ import numpy as np
 class GPT2LM(LM):
     MAX_GEN_TOKS = 256
-    VOCAB_SIZE = 50257
-    EOT_TOKEN_ID = 50256
 
-    def __init__(self, device='cuda', pretrained='gpt2', revision="main", batch_size=1):
+    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
         super().__init__()
 
         assert isinstance(device, str)
@@ -24,19 +22,31 @@ class GPT2LM(LM):
             self.device = torch.device(device)
         else:
             self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained, revision=revision).to(self.device)
+        # TODO: update this to be less of a hack once subfolder is fixed in HF
+        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")).to(self.device)
         self.gpt2.eval()
 
-        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
-        self.tokenizer.pad_token = "<|endoftext|>"
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
+        assert isinstance(self.tokenizer, (
+            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
+            transformers.T5Tokenizer, transformers.T5TokenizerFast,
+        )), "this tokenizer has not been checked for compatibility yet!"
+        self.VOCAB_SIZE = self.tokenizer.vocab_size
+        self.EOT_TOKEN_ID = self.tokenizer.eos_token_id
+        print(self.EOT_TOKEN_ID)
+
+        try:
+            self.max_length = self.gpt2.config.n_ctx
+        except AttributeError:
+            # GPTNeoConfig doesn't have n_ctx apparently
+            self.max_length = self.gpt2.config.max_position_embeddings
 
-        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
+        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
+            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
 
         # multithreading and batching
         gpus = torch.cuda.device_count()
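
The removed class constants are recoverable from any compatible tokenizer, which is what the new instance attributes rely on; a quick sketch of the equivalence (assumes `transformers` is installed):

    import transformers

    tok = transformers.AutoTokenizer.from_pretrained('gpt2')
    print(tok.vocab_size)    # 50257, the value previously hard-coded as VOCAB_SIZE
    print(tok.eos_token_id)  # 50256, the value previously hard-coded as EOT_TOKEN_ID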
@@ -62,9 +72,9 @@
             # end of text as context
                 context_enc = [self.EOT_TOKEN_ID]
             else:
-                context_enc = self.tokenizer.encode(context)
+                context_enc = self.tokenizer.encode(context, add_special_tokens=False)
 
-            continuation_enc = self.tokenizer.encode(continuation)
+            continuation_enc = self.tokenizer.encode(continuation, add_special_tokens=False)
 
             new_reqs.append(((context, continuation), context_enc, continuation_enc))
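
Passing `add_special_tokens=False` becomes necessary once T5-style tokenizers are admitted: unlike the GPT-2 tokenizer, they append `</s>` by default, which would smuggle an EOS token into every encoded context and continuation. A sketch of the difference (model name illustrative):

    import transformers

    t5_tok = transformers.T5TokenizerFast.from_pretrained('t5-small')
    enc = t5_tok.encode('hello world')                            # trailing </s> appended
    raw = t5_tok.encode('hello world', add_special_tokens=False)  # content tokens only
    assert enc[-1] == t5_tok.eos_token_id and raw == enc[:-1]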
@@ -78,7 +88,7 @@
         with torch.no_grad():
             for string, in tqdm(requests):
                 rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
-                    token_list=self.tokenizer.encode(string),
+                    token_list=self.tokenizer.encode(string, add_special_tokens=False),
                     prefix_token=self.EOT_TOKEN_ID,
                     max_seq_len=self.max_length,
                     context_len=1,
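
`get_rolling_token_windows` with `context_len=1` plus `make_disjoint_window` splits a long document into windows so that every token is scored exactly once, conditioned on whatever precedes it; the first window is conditioned on the EOT prefix token. The idea in miniature (an illustrative sketch, not the harness's actual implementation):

    def rolling_windows(tokens, max_seq_len, prefix_token):
        # yield (context, continuation) pairs whose continuations tile `tokens`
        for start in range(0, len(tokens), max_seq_len):
            context = [prefix_token] if start == 0 else tokens[start - 1:start]
            yield context, tokens[start:start + max_seq_len]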
@@ -203,7 +213,7 @@
         res = []
 
         def _collate(x):
-            toks = self.tokenizer.encode(x[0])
+            toks = self.tokenizer.encode(x[0], add_special_tokens=False)
             return (len(toks), x[0])
 
         reord = utils.Reorderer(requests, _collate)
@@ -211,9 +221,9 @@
         for context, until in tqdm(reord.get_reordered()):
             if isinstance(until, str): until = [until]
 
-            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
+            context_enc = torch.tensor([self.tokenizer.encode(context, add_special_tokens=False)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
 
-            primary_until, = self.tokenizer.encode(until[0])
+            primary_until, = self.tokenizer.encode(until[0], add_special_tokens=False)
 
             cont = self.gpt2.generate(
                 context_enc,
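
The single-element unpacking `primary_until, = ...` doubles as an assertion that each stop string encodes to exactly one token; a multi-token stop sequence would raise ValueError here. A sketch with the GPT-2 tokenizer:

    import transformers

    tok = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
    primary_until, = tok.encode('\n\n', add_special_tokens=False)  # ok: '\n\n' is a single token (628)
    # primary_until, = tok.encode('STOP HERE', add_special_tokens=False)  # would raise ValueError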