Commit 12c2ee1e authored by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness

parents cf69ba9c 61ff104e
@@ -3,44 +3,20 @@ import random
 class LM(abc.ABC):
     @abc.abstractmethod
-    def generate(self, context, max_gen_length):
-        """Conditional text generation with an LM
-
-        :param context: str
-            Context string for conditional generation
-        :param max_gen_length: int
-            Maximum number of tokens to generate
-        :return: str
-        """
-        pass
-
-    @abc.abstractmethod
     def loglikelihood(self, context, continuation):
-        """Compute log-likelihood of a generation a continuation from a context
-        Assume that the final text will simple be
-        context + continuation
+        """Compute log-likelihood of generating a continuation from a context
 
         :param context: str
-            Context string for conditional generation
+            Context string
         :param continuation: str
-            Maximum number of tokens to generate
+            The continuation over which log likelihood will be calculated. If
+            there is a word boundary, the space should be in the continuation.
+            For example, context="hello" continuation=" world" is correct.
         :return: float
         """
         pass
 
-    @classmethod
-    def num_tokens(cls, string):
-        """Return the number of tokens in a string, based on tokenization
-
-        :param string: str
-            Input string
-        :return: int
-        """
-        pass
-
     @classmethod
     def create_from_arg_string(cls, arg_string):
         """Constructor method, in case models need additional arguments
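A quick illustration of the word-boundary convention added to the loglikelihood docstring above (a sketch only; DummyLM below stands in for a real model). The space belongs to the continuation because BPE tokenizers attach a leading space to the following word:

    lm = DummyLM()
    lm.loglikelihood("hello", " world")   # correct split
    lm.loglikelihood("hello ", "world")   # wrong: space left in the context
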
@@ -5,8 +5,5 @@ from . import MODEL_REGISTRY
 @MODEL_REGISTRY.register("dummy")
 class DummyLM(LM):
-    def generate(self, context, max_gen_length):
-        return "lol"
-
     def loglikelihood(self, context, continuation):
         return 0.0
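For context on the create_from_arg_string classmethods used by the GPT-2 and GPT-3 classes below: they parse a comma-separated key=value string into constructor kwargs. A minimal, hypothetical re-implementation of what utils.simple_parse_args_string presumably does (illustration only, not the library's actual code):

    def simple_parse_args_string(args_string):
        # "device=cuda:0,engine=davinci" -> {"device": "cuda:0", "engine": "davinci"}
        return dict(kv.split("=", 1) for kv in args_string.split(",") if kv)
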
@@ -10,36 +10,21 @@ class GPT2LM(LM):
         self.device = torch.device(device)
         self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
         return cls(device=args.get("device", "cpu"))
 
-    def generate(self, context, max_gen_length, truncate=True):
-        # when too long to fit in context, truncate from the left
-        context_tensor = torch.tensor([self.tokenizer.encode(context.strip())[max_gen_length - 1024:]], dtype=torch.long).to(self.device)
-        res = self.gpt2.generate(
-            context_tensor,
-            # TODO: change to have until rather than using eos_token_id
-            eos_token_id=self.tokenizer.eos_token_id,
-            do_sample=False,
-            max_length=self.num_tokens(context) + max_gen_length,
-        )
-        # chop off the prompt and the final eos token
-        return self.tokenizer.decode(res[0][min(1024 - max_gen_length, len(context_tensor[0])):-1]).strip()
-
     def loglikelihood(self, context, continuation, truncate=True):
         # when too long to fit in context, truncate from the left
-        inp = torch.tensor([self.tokenizer.encode(context + continuation)[-1024:]], dtype=torch.long).to(self.device)
-        ctxlen = len(self.tokenizer.encode(context.strip()))
+        context_enc = self.tokenizer.encode(context)
+        continuation_enc = self.tokenizer.encode(continuation)
+        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
         cont_toks = inp[:, ctxlen:]  # [batch, seq]
         logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
         return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
-
-    def num_tokens(self, string):
-        return len(self.tokenizer.tokenize(string))
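The new ctxlen arithmetic is worth unpacking: after (context_enc + continuation_enc) is clipped to its last 1024 tokens, the overflow max(0, len(context) + len(continuation) - 1024) is exactly how many context tokens fell off the left, so ctxlen still points at the first continuation token, and logits[:, ctxlen - 1:-1] lines each prediction up with the token it predicts. A self-contained check with made-up token ids:

    context_enc = list(range(1000))        # pretend token ids
    continuation_enc = list(range(50))
    inp = (context_enc + continuation_enc)[-1024:]
    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
    assert ctxlen == 974                      # 26 context tokens were dropped
    assert inp[ctxlen:] == continuation_enc   # continuation survives intact

Note this assumes at least one context token survives (ctxlen >= 1), since position ctxlen - 1 must exist to predict the first continuation token.
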
@@ -18,7 +18,7 @@ class GPT3LM(LM):
         """
         import openai
         self.engine = engine
-        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
         self.truncate = truncate
 
         # Read from environment variable OPENAI_API_SECRET_KEY
@@ -29,49 +29,21 @@ class GPT3LM(LM):
         args = utils.simple_parse_args_string(arg_string)
         return cls(engine=args.get("engine", "davinci"))
 
-    def generate(self, context, max_gen_length):
-        import openai
-        if self.truncate:
-            prompt = self.smart_truncate(context, buffer=max_gen_length)
-        else:
-            prompt = context
-        response = openai.Completion.create(
-            engine=self.engine,
-            prompt=prompt,
-            max_tokens=max_gen_length,
-            temperature=0.0,
-        )
-        return response.choices[0]["text"]
-
     def loglikelihood(self, context, continuation):
         import openai
-        full_text = context + continuation
-        full_text_length = len(self.tokenizer.tokenize(full_text))
-        context_length = len(self.tokenizer.tokenize(context))
-        continuation_length = len(self.tokenizer.tokenize(continuation))
-        assert full_text_length == context_length + continuation_length
-        if self.truncate:
-            prompt = self.smart_truncate(full_text, buffer=0)
-        else:
-            prompt = full_text
+        context_enc = self.tokenizer.encode(context)
+        continuation_enc = self.tokenizer.encode(continuation)
+        inp = (context_enc + continuation_enc)[-1024:]
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
         response = openai.Completion.create(
             engine=self.engine,
-            prompt=prompt,
+            prompt=inp,
             echo=True,
             max_tokens=0, temperature=0.0,
             logprobs=0,
         )
         logprobs = response.choices[0]["logprobs"]["token_logprobs"]
-        continuation_logprobs = logprobs[-continuation_length:]
+        continuation_logprobs = logprobs[ctxlen:]
         return sum(continuation_logprobs)
 
-    def smart_truncate(self, string, buffer=1):
-        tokens = self.tokenizer.tokenize(string)
-        available_length = self.MAX_LENGTH - 1 - buffer  # OpenAI adds 1 token
-        kept_tokens = tokens[-available_length:]
-        new_string = self.tokenizer.convert_tokens_to_string(kept_tokens)
-        return new_string
-
-    def num_tokens(self, string):
-        return len(self.tokenizer.tokenize(string))
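The echo=True / max_tokens=0 combination makes the Completion endpoint score the prompt instead of generating: token_logprobs then holds one log-probability per prompt token (None for the very first, which has nothing to condition on), and slicing from ctxlen keeps only the continuation's share. A toy walk-through with made-up numbers:

    token_logprobs = [None, -2.1, -0.3, -4.0, -0.7]   # hypothetical API output
    ctxlen = 3                                        # first 3 tokens are context
    continuation_logprobs = token_logprobs[ctxlen:]   # [-4.0, -0.7]
    sum(continuation_logprobs)                        # -4.7 = log P(continuation | context)

Passing the token-id list inp directly as the prompt also sidesteps the retokenization mismatches the old assert guarded against, since context and continuation are now split at a known token index rather than a character boundary.
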