Commit 12e12bc0 authored by Jason Phang's avatar Jason Phang
Browse files

gpt3

parent 7a32afeb
import abc
import random
class LM(abc.ABC):
    """Abstract interface for a language model used by the evaluation harness."""

    @abc.abstractmethod
    def generate(self, context, max_gen_length):
        """Conditional text generation with an LM

        :param context: str
            Context string for conditional generation
        :param max_gen_length: int
            Maximum number of tokens to generate
        :return: str
            The generated continuation text
        """
        pass

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute log-prob of generating a continuation from a context

        Assume that the final text will simply be
            context + continuation

        :param context: str
            Context string the continuation is conditioned on
        :param continuation: str
            Continuation whose log-likelihood is scored
        :return: float
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        :param arg_string: str
            Left up to individual model class to handle
        """
        # Default: ignore the arg string; subclasses with required
        # constructor arguments override this.
        return cls()
class Dataset(abc.ABC):
@abc.abstractmethod
......@@ -50,4 +82,17 @@ class Dataset(abc.ABC):
@abc.abstractmethod
def evaluate(self, docs, lm, provide_description, num_fewshot):
    """Run the task's evaluation over *docs* using the given language model.

    :param docs: iterable of task documents to evaluate on
    :param lm: LM instance used for generation / scoring
    :param provide_description: presumably whether to prepend a task
        description to each prompt -- TODO confirm against harness caller
    :param num_fewshot: int, number of few-shot examples per prompt
    """
    pass
class Registry:
    """Simple name -> class registry, populated via the ``register`` decorator."""

    def __init__(self, registry_name):
        # Human-readable name used in error messages (e.g. "models", "tasks").
        self.registry_name = registry_name
        self.registry = {}

    def register(self, name):
        """Return a class decorator that registers its target under *name*.

        :param name: str, key the class will be registered under
        :raises ValueError: if *name* is already registered
        """
        def register_cls(new_cls):
            if name in self.registry:
                # Bug fix: the original format string had one placeholder but
                # two arguments, so the duplicated name was silently dropped
                # from the error message.
                raise ValueError('Cannot register duplicate {} ({})'.format(self.registry_name, name))
            self.registry[name] = new_cls
            return new_cls
        return register_cls
from gpt2 import GPT2LM
from models.gpt2 import GPT2LM
lm = GPT2LM()
......
import importlib
import os
from ..base import Registry
# Registry mapping model names (e.g. "gpt2", "gpt3") to their classes.
MODEL_REGISTRY = Registry(registry_name="models")

# Load all modules in models directory to populate registry
# (importing each module runs its @MODEL_REGISTRY.register(...) decorators).
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    # Skip private/hidden entries; accept .py files and subpackages.
    if (
        not file.startswith('_')
        and not file.startswith('.')
        and (file.endswith('.py') or os.path.isdir(path))
    ):
        # Strip the .py suffix for modules; directories are used as-is.
        module_name = file[:file.find('.py')] if file.endswith('.py') else file
        module = importlib.import_module('lm_evaluation_harness.models.' + module_name)
def get_model(model_name):
    """Look up a registered model class by its registry name.

    :param model_name: str, key the model class was registered under
    :return: the model class
    """
    registered_models = MODEL_REGISTRY.registry
    return registered_models[model_name]
import transformers
import torch
from ..base import LM
from . import MODEL_REGISTRY
@MODEL_REGISTRY.register("gpt2")
class GPT2LM(LM):
    """Local GPT-2 model (HuggingFace transformers) implementing the LM interface."""

    def __init__(self):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

    def generate(self, context, max_gen_length):
        """Greedy-decode a continuation of *context*, up to *max_gen_length* total tokens."""
        context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long)
        res = self.gpt2.generate(
            context,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=False,  # greedy decoding
            max_length=max_gen_length,
        )
        # chop off the prompt and the final eos token
        # Bug fix: the original referenced ``self.tok``, an attribute that is
        # never defined; the tokenizer lives in ``self.tokenizer``.
        return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        # Renamed from ``nll_of`` so the abstract ``LM.loglikelihood`` is
        # implemented (otherwise instantiating GPT2LM raises TypeError).
        # Not yet implemented.
        pass
import os
import openai
import transformers
from ..base import LM
from .. import utils
from . import MODEL_REGISTRY
@MODEL_REGISTRY.register("gpt3")
class GPT3LM(LM):
    """OpenAI GPT-3 API model implementing the LM interface."""

    def __init__(self, engine):
        # engine: str, OpenAI engine name (e.g. "davinci")
        self.engine = engine
        # GPT-2 tokenizer is used locally to count tokens for logprob slicing.
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build a GPT3LM from a "k=v,..." arg string.

        Renamed from ``create_from_args``: the base class declares
        ``create_from_arg_string``, and the inherited default (``cls()``)
        would fail because ``engine`` is a required constructor argument.
        """
        args = utils.simple_parse_args_string(arg_string)
        return cls(engine=args.get("engine", "davinci"))

    def generate(self, context, max_gen_length):
        """Greedy (temperature 0) completion of *context* via the API."""
        response = openai.Completion.create(
            engine=self.engine,
            prompt=context,
            max_tokens=max_gen_length,
            temperature=0.0,
        )
        return response.choices[0]["text"]

    def loglikelihood(self, context, continuation):
        """Sum of API token log-probs over the continuation's tokens.

        Renamed from ``logprob_of`` to implement the abstract
        ``LM.loglikelihood`` interface.
        """
        full_text = context + continuation
        full_text_length = len(self.tokenizer.tokenize(full_text))
        context_length = len(self.tokenizer.tokenize(context))
        continuation_length = len(self.tokenizer.tokenize(continuation))
        # Assumes the BPE tokenization of the concatenation lines up with the
        # parts tokenized separately -- TODO confirm for boundaries where
        # tokens could merge across context/continuation.
        assert full_text_length == context_length + continuation_length

        response = openai.Completion.create(
            engine=self.engine,
            prompt=full_text,
            max_tokens=0,
            temperature=0.0,
            logprobs=0,
            # Bug fix: with max_tokens=0 the API only returns token logprobs
            # for the prompt when echo=True; without it token_logprobs is
            # empty and the slice below returns nothing.
            echo=True,
        )
        logprobs = response.choices[0]["logprobs"]["token_logprobs"]
        continuation_logprobs = logprobs[-continuation_length:]
        return sum(continuation_logprobs)
import importlib
import os
from ..base import Registry
# Registry mapping task names (e.g. "coqa") to their Dataset classes.
TASK_REGISTRY = Registry(registry_name="tasks")

# Load all modules in the tasks directory to populate the registry
# (importing each module runs its @TASK_REGISTRY.register(...) decorators).
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
    path = os.path.join(tasks_dir, file)
    # Skip private/hidden entries; accept .py files and subpackages.
    if (
        not file.startswith('_')
        and not file.startswith('.')
        and (file.endswith('.py') or os.path.isdir(path))
    ):
        # Strip the .py suffix for modules; directories are used as-is.
        module_name = file[:file.find('.py')] if file.endswith('.py') else file
        module = importlib.import_module('lm_evaluation_harness.tasks.' + module_name)
def get_task(model_name):
    """Look up a registered task class by name.

    NOTE(review): despite the name, the argument is a task name -- the
    parameter name looks copy-pasted from ``get_model``; kept unchanged
    for backward compatibility with keyword callers.
    """
    registered_tasks = TASK_REGISTRY.registry
    return registered_tasks[model_name]
from base import Dataset
import os
import json
import random
from ..base import Dataset
from . import TASK_REGISTRY
@TASK_REGISTRY.register("coqa")
class CoQA(Dataset):
def has_training_docs(self):
    """Whether this task provides training documents (always True for CoQA)."""
    return True
......
def simple_parse_args_string(args_string):
    """
    Parses something like
        args1=val1,arg2=val2
    Into a dictionary

    :param args_string: str, comma-separated "key=value" pairs (may be
        empty or whitespace-only)
    :return: dict mapping each key to its (string) value
    """
    # Bug fixes: the original called .split() (yielding a list, which has no
    # .split(",") method) where .strip() was meant, and iterated with
    # ``for arg, in arg_list``, which tries to unpack each string into a
    # one-element tuple and raises ValueError.
    args_string = args_string.strip()
    if not args_string:
        return {}
    args_dict = {}
    for arg in args_string.split(","):
        k, v = arg.split("=")
        args_dict[k] = v
    return args_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment