Commit ee5467ff authored by Jason Phang
Browse files

adding tests sortof

parent 3a490624
......@@ -26,7 +26,18 @@ def test_evaluator(taskname, Task):
res.append((-random.random(), False))
return res
def ll_perp_fn(reqs):
    """Stub for ``loglikelihood_perplexity`` that returns reproducible fake scores.

    Every request must be a 1-tuple holding a string. The global ``random``
    generator is re-seeded with 42 on each call, so the result depends only
    on how many requests are passed in.

    Returns a list with one ``(score,)`` tuple per request, where ``score``
    is the negation of a ``random.random()`` draw.
    """
    # Validate all requests before producing anything; the 1-tuple unpacking
    # also rejects requests of any other arity.
    for (text,) in reqs:
        assert isinstance(text, str)

    random.seed(42)
    return [(-random.random(),) for _ in reqs]
# Swap the model's scoring methods for the stub functions so the evaluator
# runs against fake scores instead of the model's own implementations.
lm.loglikelihood = ll_fn
lm.loglikelihood_perplexity = ll_perp_fn
# NOTE(review): the positional arguments False, 0, 10 are not explained here;
# confirm their meaning against the signature of evaluator.evaluate.
evaluator.evaluate(lm, task_dict, False, 0, 10)
......@@ -34,4 +34,20 @@ def test_gpt2():
targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281]
for (pred, _), tgt in zip(vals, targets):
assert pred == pytest.approx(tgt)
\ No newline at end of file
assert pred == pytest.approx(tgt)
def test_gpt2_perplexity():
    """Compare GPT-2's ``loglikelihood_perplexity`` output with recorded targets.

    The same sentence is scored twice on CPU: once with the model as created,
    and once after forcing ``max_length`` down to 5 so that the rolling-window
    code path is exercised. Each run's values (presumably one log-likelihood
    per token; confirm against the model implementation) must match the
    recorded targets within ``pytest.approx``'s default tolerance.
    """
    gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
    test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."

    perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
    targets = [-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487]
    # zip() stops at the shorter input, so without this check a truncated or
    # empty prediction list would let the loop below pass vacuously.
    assert len(perplexity) == len(targets)
    for pred, tgt in zip(perplexity, targets):
        assert pred == pytest.approx(tgt)

    # Hack: modify gpt2 to have shorter context length to induce rolling windows
    gpt2.max_length = 5
    perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
    targets = [-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813]
    # Same guard as above for the rolling-window run.
    assert len(perplexity) == len(targets)
    for pred, tgt in zip(perplexity, targets):
        assert pred == pytest.approx(tgt)
from lm_eval.utils import get_rolling_token_windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v1():
    """34 tokens, ``max_seq_len=10``, ``context_len=1``.

    The first three windows each predict 10 tokens. The last window still
    carries 10 input tokens but predicts only the 4 tokens that remain.
    """
    tokens = list(range(34))
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
        ([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
        ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
    ]

    windows = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    ]

    # Every token must be predicted exactly once across all windows.
    total_predicted = sum(len(preds) for _, preds in windows)
    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v2():
    """34 tokens, ``max_seq_len=10``, ``context_len=8``.

    After the prefix-padded first window, each window advances by three
    tokens and predicts only those three.
    """
    tokens = list(range(34))
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [10, 11, 12]),
        ([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [13, 14, 15]),
        ([8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [16, 17, 18]),
        ([11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [19, 20, 21]),
        ([14, 15, 16, 17, 18, 19, 20, 21, 22, 23], [22, 23, 24]),
        ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [25, 26, 27]),
        ([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [28, 29, 30]),
        ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [31, 32, 33]),
    ]

    windows = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=8,
        )
    ]

    # Every token must be predicted exactly once across all windows.
    total_predicted = sum(len(preds) for _, preds in windows)
    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v3():
    """34 tokens, ``max_seq_len=10``, ``context_len=10``.

    After the prefix-padded first window, each window slides forward by one
    token and predicts exactly one new token.
    """
    tokens = list(range(34))

    # The first window starts with the prefix token and predicts tokens 0-9.
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
    ]
    # The remaining 24 windows are [start, ..., start + 9] -> [start + 10],
    # ending with ([23, ..., 32], [33]).
    expected += [
        (list(range(start, start + 10)), [start + 10])
        for start in range(24)
    ]

    windows = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=10,
        )
    ]

    # Every token must be predicted exactly once across all windows.
    total_predicted = sum(len(preds) for _, preds in windows)
    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v4():
    """30 tokens, ``max_seq_len=10``, ``context_len=10``.

    Same one-token-per-window pattern as the 34-token case, but with a token
    count that is a multiple of ``max_seq_len``.
    """
    tokens = list(range(30))

    # The first window starts with the prefix token and predicts tokens 0-9.
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
    ]
    # The remaining 20 windows are [start, ..., start + 9] -> [start + 10],
    # ending with ([19, ..., 28], [29]).
    expected += [
        (list(range(start, start + 10)), [start + 10])
        for start in range(20)
    ]

    windows = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=10,
        )
    ]

    # Every token must be predicted exactly once across all windows.
    total_predicted = sum(len(preds) for _, preds in windows)
    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v5():
    """30 tokens, ``max_seq_len=10``, ``context_len=1``.

    The token count divides evenly into three windows of 10 predictions
    each, so there is no shortened final window.
    """
    tokens = list(range(30))
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
        ([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
    ]

    windows = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    ]

    # Every token must be predicted exactly once across all windows.
    total_predicted = sum(len(preds) for _, preds in windows)
    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v6():
    """9 tokens, ``max_seq_len=2``, ``context_len=1``.

    Windows predict two tokens at a time. The odd token left at the end is
    predicted alone by a final window that still has two input tokens.
    """
    tokens = list(range(9))
    expected = [
        ([-100, 0], [0, 1]),
        ([1, 2], [2, 3]),
        ([3, 4], [4, 5]),
        ([5, 6], [6, 7]),
        ([6, 7], [8]),
    ]

    windows = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=2,
            context_len=1,
        )
    ]

    # Every token must be predicted exactly once across all windows.
    total_predicted = sum(len(preds) for _, preds in windows)
    assert total_predicted == len(tokens)
    assert expected == windows
def test_get_rolling_token_windows_empty():
    """An empty token list must produce no windows at all."""
    windows = get_rolling_token_windows(
        token_list=[],
        prefix_token=-100,
        max_seq_len=2,
        context_len=1,
    )
    # Count the yielded items; exhausting the generator also surfaces any
    # error raised while iterating.
    window_count = sum(1 for _ in windows)
    assert window_count == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment