"src/turbomind/utils/memory_utils.h" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Commit 90e50b4c authored by Leo Gao

Refactor to remove generate and fix some bad tokenization

In particular, the following assumptions are FALSE in general:
tokenize(context + continuation) = tokenize(context) + tokenize(continuation)
len(tokenize(context + continuation)) = len(tokenize(context)) + len(tokenize(continuation))
tokenize(context + continuation)[:len(tokenize(context))] = tokenize(context)

So we need to tip-toe around the problem by being careful with how we do it.

In particular, using Fast is not just for performance; while the behaviour of GPT2Tokenizer differs across Transformers 2 and 3, GPT2TokenizerFast's doesn't.
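A minimal sketch of why these identities fail, assuming GPT2TokenizerFast from Transformers; the example strings are illustrative and the exact token ids depend on the GPT-2 vocabulary:

# Illustrative check of the three assumptions above; a space at the
# context/continuation boundary is a typical case where tokens merge.
import transformers

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

context, continuation = "hello ", "world"
joint = tokenizer.encode(context + continuation)
separate = tokenizer.encode(context) + tokenizer.encode(continuation)

print(joint == separate)          # generally False here: " world" merges into one token
print(len(joint), len(separate))  # so the lengths can disagree too
print(joint[:len(tokenizer.encode(context))] == tokenizer.encode(context))  # and the prefix need not match

This is also why the docstring below asks for the boundary space to live in the continuation (context="hello", continuation=" world").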
parent 6de520af
@@ -3,19 +3,6 @@ import random
 class LM(abc.ABC):
-    @abc.abstractmethod
-    def generate(self, context, max_gen_length):
-        """Conditional text generation with an LM
-
-        :param context: str
-            Context string for conditional generation
-        :param max_gen_length: int
-            Maximum number of tokens to generate
-        :return: str
-        """
-        pass
-
     @abc.abstractmethod
     def loglikelihood(self, context, continuation):
         """Compute log-likelihood of a generation a continuation from a context
@@ -24,9 +11,11 @@ class LM(abc.ABC):
             context + continuation

         :param context: str
-            Context string for conditional generation
+            Context string
         :param continuation: str
-            Maximum number of tokens to generate
+            The continuation over which log likelihood will be calculated. If
+            there is a word boundary, the space should be in the continuation.
+            For example, context="hello" continuation=" world" is correct.
         :return: float
         """
         pass
...
@@ -10,31 +10,19 @@ class GPT2LM(LM):
         self.device = torch.device(device)
         self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
         return cls(device=args.get("device", "cpu"))

-    def generate(self, context, max_gen_length, truncate=True):
-        # when too long to fit in context, truncate from the left
-        context_tensor = torch.tensor([self.tokenizer.encode(context.strip())[max_gen_length - 1024:]], dtype=torch.long).to(self.device)
-        res = self.gpt2.generate(
-            context_tensor,
-            # TODO: change to have until rather than using eos_token_id
-            eos_token_id=self.tokenizer.eos_token_id,
-            do_sample=False,
-            max_length=self.num_tokens(context) + max_gen_length,
-        )
-
-        # chop off the prompt and the final eos token
-        return self.tokenizer.decode(res[0][min(1024 - max_gen_length, len(context_tensor[0])):-1]).strip()
-
     def loglikelihood(self, context, continuation, truncate=True):
         # when too long to fit in context, truncate from the left
-        inp = torch.tensor([self.tokenizer.encode(context + continuation)[-1024:]], dtype=torch.long).to(self.device)
-        ctxlen = len(self.tokenizer.encode(context.strip()))
+        context_enc = self.tokenizer.encode(context)
+        continuation_enc = self.tokenizer.encode(continuation)
+        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)

         cont_toks = inp[:, ctxlen:]  # [batch, seq]
         logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
...
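A small worked example of the left-truncation arithmetic in the new loglikelihood; the token counts below are hypothetical, and 1024 is the GPT-2 context window:

# Hypothetical token counts illustrating the ctxlen formula above.
max_len = 1024
len_ctx, len_cont = 1000, 100                    # 1100 tokens total, 76 too many
overflow = max(0, len_ctx + len_cont - max_len)  # 76 tokens get dropped from the left
ctxlen = len_ctx - overflow                      # 924 context tokens survive in inp
assert overflow == 76 and ctxlen == 924
# As long as the continuation itself fits in the window, only context tokens
# are dropped, so inp[:, ctxlen:] still lines up exactly with continuation_enc.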
@@ -18,7 +18,7 @@ class GPT3LM(LM):
         """
         import openai
         self.engine = engine
-        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
         self.truncate = truncate

         # Read from environment variable OPENAI_API_SECRET_KEY
@@ -29,49 +29,22 @@ class GPT3LM(LM):
         args = utils.simple_parse_args_string(arg_string)
         return cls(engine=args.get("engine", "davinci"))

-    def generate(self, context, max_gen_length):
+    def loglikelihood(self, context, continuation):
         import openai
-        if self.truncate:
-            prompt = self.smart_truncate(context, buffer=max_gen_length)
-        else:
-            prompt = context
-        response = openai.Completion.create(
-            engine=self.engine,
-            prompt=prompt,
-            max_tokens=max_gen_length,
-            temperature=0.0,
-        )
-        return response.choices[0]["text"]
-
-    def loglikelihood(self, context, continuation):
-        import openai
-        full_text = context + continuation
-        full_text_length = len(self.tokenizer.tokenize(full_text))
-        context_length = len(self.tokenizer.tokenize(context))
-        continuation_length = len(self.tokenizer.tokenize(continuation))
-        assert full_text_length == context_length + continuation_length
-        if self.truncate:
-            prompt = self.smart_truncate(full_text, buffer=0)
-        else:
-            prompt = full_text
+        context_enc = self.tokenizer.encode(context)
+        continuation_enc = self.tokenizer.encode(continuation)
+        inp = (context_enc + continuation_enc)[-1024:]
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
         response = openai.Completion.create(
             engine=self.engine,
-            prompt=prompt,
+            prompt=inp,
             echo=True,
             max_tokens=0, temperature=0.0,
             logprobs=0,
         )
         logprobs = response.choices[0]["logprobs"]["token_logprobs"]
-        continuation_logprobs = logprobs[-continuation_length:]
+        continuation_logprobs = logprobs[ctxlen:]

         return sum(continuation_logprobs)
-
-    def smart_truncate(self, string, buffer=1):
-        tokens = self.tokenizer.tokenize(string)
-        available_length = self.MAX_LENGTH - 1 - buffer  # OpenAI adds 1 token
-        kept_tokens = tokens[-available_length:]
-        new_string = self.tokenizer.convert_tokens_to_string(kept_tokens)
-        return new_string
-
-    def num_tokens(self, string):
-        return len(self.tokenizer.tokenize(string))