Commit 53f6bc34 authored by Jason Phang's avatar Jason Phang
Browse files

update GPT3 test data and more docs

parent 76ebb792
......@@ -11,7 +11,7 @@ import numpy as np
def simple_evaluate(model, model_args, task_names,
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000):
"""
"""Instantiate and evaluate a model on a list of tasks.
:param model: str
Name of model, see lm_eval.models.get_model
......@@ -24,7 +24,7 @@ def simple_evaluate(model, model_args, task_names,
:param batch_size: int, optional
Batch size for model
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
Whether or not to cache model responses (presumably disables the result cache when True — confirm against lm_eval caching implementation)
:param limit: int, optional
......@@ -32,6 +32,7 @@ def simple_evaluate(model, model_args, task_names,
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:return:
Dictionary of results
"""
random.seed(1234)
np.random.seed(1234)
......@@ -64,6 +65,23 @@ def simple_evaluate(model, model_args, task_names,
def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
Language Model
:param task_dict: dict[str, Task]
Dictionary of tasks
:param provide_description: bool
NOT IMPLEMENTED
:param num_fewshot: int
Number of examples in few-shot context
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:return:
Dictionary of results
"""
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
# TODO: implement proper description-providing system
......
......@@ -8,6 +8,18 @@ import time
def get_result(response, ctxlen):
"""Process results from OpenAI API response.
:param response: dict
OpenAI API Response
:param ctxlen: int
Length of context (so we can slice them away and only keep the predictions)
:return:
continuation_logprobs: np.array
Log probabilities of continuation tokens
is_greedy: bool
whether argmax matches given continuation exactly
"""
is_greedy = True
logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:])
......
......@@ -19,6 +19,7 @@ def mock_completion(**kwargs):
with open(fname, 'rb') as fh:
return pickle.load(fh)
ret = openai.Completion.create(**kwargs)
ret.api_key = ""
with open(fname, 'wb') as fh:
pickle.dump(ret, fh)
return ret
......@@ -65,8 +66,8 @@ def test_gpt3():
print([x[0] for x in vals])
targets = [
-34.85833048, -47.114367866, -45.43520782100001, -5.289627985, -133.96879783896998, -321.30299892039994,
-658.0542459504098, -34.85833048, -7.5162964
-34.848301606999996, -47.148329679999996, -45.44380149599999, -5.285246016, -133.97821690686004,
-321.2616693239001, -658.0299524401041, -34.848301606999996, -7.525115,
]
for (pred, _), tgt in zip(vals, targets):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment