Commit 53f6bc34 authored by Jason Phang
Browse files

update GPT3 test data and more docs

parent 76ebb792
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
def simple_evaluate(model, model_args, task_names, def simple_evaluate(model, model_args, task_names,
num_fewshot=0, batch_size=None, device=None, num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000): no_cache=False, limit=None, bootstrap_iters=100000):
""" """Instantiate and evaluate a model on a list of tasks.
:param model: str :param model: str
Name of model, see lm_eval.models.get_model Name of model, see lm_eval.models.get_model
...@@ -24,7 +24,7 @@ def simple_evaluate(model, model_args, task_names, ...@@ -24,7 +24,7 @@ def simple_evaluate(model, model_args, task_names,
:param batch_size: int, optional :param batch_size: int, optional
Batch size for model Batch size for model
:param device: str, optional :param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool :param no_cache: bool
Whether or not to disable caching of model results Whether or not to disable caching of model results
:param limit: int, optional :param limit: int, optional
...@@ -32,6 +32,7 @@ def simple_evaluate(model, model_args, task_names, ...@@ -32,6 +32,7 @@ def simple_evaluate(model, model_args, task_names,
:param bootstrap_iters: :param bootstrap_iters:
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:return: :return:
Dictionary of results
""" """
random.seed(1234) random.seed(1234)
np.random.seed(1234) np.random.seed(1234)
...@@ -64,6 +65,23 @@ def simple_evaluate(model, model_args, task_names, ...@@ -64,6 +65,23 @@ def simple_evaluate(model, model_args, task_names,
def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000): def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
"""Evaluate an already-instantiated language model on a dictionary of tasks.
:param lm: obj
Language Model
:param task_dict: dict[str, Task]
Dictionary of tasks
:param provide_description: bool
NOT IMPLEMENTED
:param num_fewshot: int
Number of examples in few-shot context
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:return:
Dictionary of results
"""
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
# TODO: implement proper description-providing system # TODO: implement proper description-providing system
......
...@@ -8,6 +8,18 @@ import time ...@@ -8,6 +8,18 @@ import time
def get_result(response, ctxlen): def get_result(response, ctxlen):
"""Process results from OpenAI API response.
:param response: dict
OpenAI API Response
:param ctxlen: int
Length of context (so we can slice them away and only keep the predictions)
:return:
continuation_logprobs: float
Sum of the log probabilities of the continuation tokens
is_greedy: bool
whether argmax matches given continuation exactly
"""
is_greedy = True is_greedy = True
logprobs = response["logprobs"]["token_logprobs"] logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:]) continuation_logprobs = sum(logprobs[ctxlen:])
......
...@@ -19,6 +19,7 @@ def mock_completion(**kwargs): ...@@ -19,6 +19,7 @@ def mock_completion(**kwargs):
with open(fname, 'rb') as fh: with open(fname, 'rb') as fh:
return pickle.load(fh) return pickle.load(fh)
ret = openai.Completion.create(**kwargs) ret = openai.Completion.create(**kwargs)
ret.api_key = ""
with open(fname, 'wb') as fh: with open(fname, 'wb') as fh:
pickle.dump(ret, fh) pickle.dump(ret, fh)
return ret return ret
...@@ -65,8 +66,8 @@ def test_gpt3(): ...@@ -65,8 +66,8 @@ def test_gpt3():
print([x[0] for x in vals]) print([x[0] for x in vals])
targets = [ targets = [
-34.85833048, -47.114367866, -45.43520782100001, -5.289627985, -133.96879783896998, -321.30299892039994, -34.848301606999996, -47.148329679999996, -45.44380149599999, -5.285246016, -133.97821690686004,
-658.0542459504098, -34.85833048, -7.5162964 -321.2616693239001, -658.0299524401041, -34.848301606999996, -7.525115,
] ]
for (pred, _), tgt in zip(vals, targets): for (pred, _), tgt in zip(vals, targets):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment