Commit 90e50b4c authored by Leo Gao

Refactor to remove generate and fix some bad tokenization

In particular, the following assumptions are FALSE in general:
tokenize(context + continuation) = tokenize(context) + tokenize(continuation)
len(tokenize(context + continuation)) = len(tokenize(context)) + len(tokenize(continuation))
tokenize(context + continuation)[:len(tokenize(context))] = tokenize(context)

So we have to work around this: encode the context and continuation separately and concatenate the resulting token lists, rather than tokenizing the concatenated string and guessing where the boundary falls.

In particular, switching to the Fast tokenizer is not just for performance: the behaviour of GPT2Tokenizer differs between Transformers 2 and 3, while GPT2TokenizerFast's does not.
parent 6de520af
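For illustration, a minimal sketch (not part of this commit, and assuming the transformers package is available) that checks the three identities above on a concrete string pair and shows the workaround the diff below adopts: encode the context and continuation separately and concatenate the token lists.

# Illustration only; token IDs are whatever the tokenizer produces, nothing is asserted.
import transformers

tok = transformers.GPT2TokenizerFast.from_pretrained("gpt2")

# A continuation that starts mid-word makes the mismatch easy to see, because
# BPE merges can cross the context/continuation boundary.
context = "The quick brown fox jumps over the lazy do"
continuation = "g."

whole_enc = tok.encode(context + continuation)
context_enc = tok.encode(context)
continuation_enc = tok.encode(continuation)

# None of the three identities above are guaranteed to hold:
print(whole_enc == context_enc + continuation_enc)
print(len(whole_enc) == len(context_enc) + len(continuation_enc))
print(whole_enc[:len(context_enc)] == context_enc)

# The workaround used below: encode each piece separately and concatenate the
# token lists, so the index of the context/continuation split is exact.
inp = context_enc + continuation_enc
ctxlen = len(context_enc)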
......@@ -3,19 +3,6 @@ import random
class LM(abc.ABC):
@abc.abstractmethod
def generate(self, context, max_gen_length):
"""Conditional text generation with an LM
:param context: str
Context string for conditional generation
:param max_gen_length: int
Maximum number of tokens to generate
:return: str
"""
pass
@abc.abstractmethod
def loglikelihood(self, context, continuation):
"""Compute log-likelihood of a generation a continuation from a context
......@@ -24,9 +11,11 @@ class LM(abc.ABC):
context + continuation
:param context: str
Context string for conditional generation
Context string
:param continuation: str
Maximum number of tokens to generate
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: float
"""
pass
......
......@@ -10,31 +10,19 @@ class GPT2LM(LM):
self.device = torch.device(device)
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
self.gpt2.eval()
self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
@classmethod
def create_from_arg_string(cls, arg_string):
args = utils.simple_parse_args_string(arg_string)
return cls(device=args.get("device", "cpu"))
def generate(self, context, max_gen_length, truncate=True):
# when too long to fit in context, truncate from the left
context_tensor = torch.tensor([self.tokenizer.encode(context.strip())[max_gen_length - 1024:]], dtype=torch.long).to(self.device)
res = self.gpt2.generate(
context_tensor,
# TODO: change to have until rather than using eos_token_id
eos_token_id=self.tokenizer.eos_token_id,
do_sample=False,
max_length=self.num_tokens(context) + max_gen_length,
)
# chop off the prompt and the final eos token
return self.tokenizer.decode(res[0][min(1024 - max_gen_length, len(context_tensor[0])):-1]).strip()
def loglikelihood(self, context, continuation, truncate=True):
# when too long to fit in context, truncate from the left
inp = torch.tensor([self.tokenizer.encode(context + continuation)[-1024:]], dtype=torch.long).to(self.device)
ctxlen = len(self.tokenizer.encode(context.strip()))
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
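# ctxlen: number of context tokens that survive left-truncation to the 1024-token
# window; everything from that index onward in inp belongs to the continuation.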
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
cont_toks = inp[:, ctxlen:] # [batch, seq]
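# Logits at position i predict token i + 1, so the slice [ctxlen - 1:-1] lines the
# predictions up with the continuation tokens.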
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
......
......@@ -18,7 +18,7 @@ class GPT3LM(LM):
"""
import openai
self.engine = engine
self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.truncate = truncate
# Read from environment variable OPENAI_API_SECRET_KEY
......@@ -29,49 +29,22 @@ class GPT3LM(LM):
args = utils.simple_parse_args_string(arg_string)
return cls(engine=args.get("engine", "davinci"))
def generate(self, context, max_gen_length):
import openai
if self.truncate:
prompt = self.smart_truncate(context, buffer=max_gen_length)
else:
prompt = context
response = openai.Completion.create(
engine=self.engine,
prompt=prompt,
max_tokens=max_gen_length,
temperature=0.0,
)
return response.choices[0]["text"]
def loglikelihood(self, context, continuation):
import openai
full_text = context + continuation
full_text_length = len(self.tokenizer.tokenize(full_text))
context_length = len(self.tokenizer.tokenize(context))
continuation_length = len(self.tokenizer.tokenize(continuation))
assert full_text_length == context_length + continuation_length
if self.truncate:
prompt = self.smart_truncate(full_text, buffer=0)
else:
prompt = full_text
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = (context_enc + continuation_enc)[-1024:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
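# The prompt is sent as a list of token IDs rather than a string, so the API does not
# re-tokenize it and the context/continuation split computed above stays exact.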
response = openai.Completion.create(
engine=self.engine,
prompt=prompt,
prompt=inp,
echo=True,
max_tokens=0, temperature=0.0,
logprobs=0,
)
logprobs = response.choices[0]["logprobs"]["token_logprobs"]
continuation_logprobs = logprobs[-continuation_length:]
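# Slice by the exact token index computed above instead of re-counting tokens in the
# continuation string, which could disagree with how the full prompt was tokenized.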
continuation_logprobs = logprobs[ctxlen:]
return sum(continuation_logprobs)
def smart_truncate(self, string, buffer=1):
tokens = self.tokenizer.tokenize(string)
available_length = self.MAX_LENGTH - 1 - buffer # OpenAI adds 1 token
kept_tokens = tokens[-available_length:]
new_string = self.tokenizer.convert_tokens_to_string(kept_tokens)
return new_string
def num_tokens(self, string):
return len(self.tokenizer.tokenize(string))