from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
import torch
import lm_eval.tasks as tasks
from lm_eval.api.instance import Instance
from lm_eval.models.huggingface import HFLM


tasks.initialize_tasks()


class Test_HFLM:
    """Regression tests for the HuggingFace model wrapper (``HFLM``).

    A small deterministic model (``EleutherAI/pythia-70m``, fp32, CPU) is run
    against a fixed slice of task instances; outputs are compared against
    previously recorded reference values.  All fixtures are built once at
    class-definition time so every test method shares the same model and
    request instances.
    """

    # Deterministic torch kernels so the recorded reference numbers reproduce.
    torch.use_deterministic_algorithms(True)
    # Interpreter minor version; used to tag the CI output-log filename.
    version_minor = sys.version_info.minor

    # 10 arc_easy docs -> one loglikelihood request per answer choice.
    multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")()  # type: ignore
    multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
    MULTIPLE_CH: list[Instance] = multiple_choice_task.instances

    # 10 gsm8k docs; generation capped at 10 tokens to keep the test fast.
    generate_until_task = tasks.TASK_REGISTRY.get("gsm8k")()  # type: ignore
    generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
    generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
    generate_until: list[Instance] = generate_until_task.instances

    # 10 wikitext docs for rolling (full-document) loglikelihood.
    rolling_task = tasks.TASK_REGISTRY.get("wikitext")()  # type: ignore
    rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
    ROLLING: list[Instance] = rolling_task.instances

    # Reference log-likelihoods recorded from pythia-70m (fp32, CPU); 4 values
    # per arc_easy document (one per answer choice).
    MULTIPLE_CH_RES = [
        -41.902435302734375,
        -42.939308166503906,
        -33.914180755615234,
        -37.07139205932617,
        -22.95258331298828,
        -20.342208862304688,
        -14.818366050720215,
        -27.942853927612305,
        -15.80704116821289,
        -15.936427116394043,
        -13.052018165588379,
        -18.04828453063965,
        -13.345029830932617,
        -13.366025924682617,
        -12.127134323120117,
        -11.872495651245117,
        -47.10598373413086,
        -47.76410675048828,
        -36.4406852722168,
        -50.0289421081543,
        -16.72093963623047,
        -18.535587310791016,
        -26.46993637084961,
        -20.355995178222656,
        -17.757919311523438,
        -21.80595588684082,
        -33.1990852355957,
        -39.28636932373047,
        -14.759679794311523,
        -16.753942489624023,
        -11.486852645874023,
        -15.42177677154541,
        -13.15798282623291,
        -15.887393951416016,
        -15.28614616394043,
        -12.339089393615723,
        -44.59441375732422,
        -55.40888214111328,
        -52.70050811767578,
        -56.25089645385742,
    ]
    # Reference greedy generations (truncated at max_gen_toks=10).
    generate_until_RES = [
        " The average of $2.50 each is $",
        " A robe takes 2 bolts of blue fiber and half",
        " $50,000 in repairs.",
        " He runs 1 sprint 3 times a week.",
        " They feed each of her chickens three cups of mixed",
        " The price of the glasses is $5, but",
        " The total percentage of students who said they like to",
        " Carla is downloading a 200 GB file. Normally",
        " John drives for 3 hours at a speed of 60",
        " Eliza sells 4 tickets to 5 friends so she",
    ]
    # Reference rolling log-likelihoods, one per wikitext document.
    ROLLING_RES = [
        -3603.6328125,
        -19779.23974609375,
        -8834.16455078125,
        -27967.591796875,
        -7636.794982910156,
        -9491.93505859375,
        -41043.4248046875,
        -8397.689819335938,
        -45969.47155761719,
        -7158.90625,
    ]
    LM = HFLM(pretrained="EleutherAI/pythia-70m", device="cpu", dtype="float32")

    # NOTE: "logliklihood" is a historical typo in the method names; kept so
    # existing CI name filters (pytest -k) and log history keep matching.
    def test_logliklihood(self) -> None:
        """Scores match the recorded references and pick the same answers."""
        res = self.LM.loglikelihood(self.MULTIPLE_CH)
        _RES, _res = self.MULTIPLE_CH_RES, [r[0] for r in res]
        # Persist the raw scores so CI artifacts can be diffed when the
        # tolerance check below trips.
        dir_path = Path("test_logs")
        dir_path.mkdir(parents=True, exist_ok=True)
        file_path = dir_path / f"outputs_log_{self.version_minor}.txt"
        # Explicit encoding: the default is platform-dependent.
        file_path.write_text("\n".join(str(x) for x in _res), encoding="utf-8")
        assert np.allclose(_res, _RES, atol=1e-2)
        # Check indices for Multiple Choice: the argmax over each group of 4
        # choices (arc_easy has 4 options per doc) must select the same answer.
        argmax_RES, argmax_res = (
            np.argmax(np.array(_RES).reshape(-1, 4), axis=1),
            np.argmax(np.array(_res).reshape(-1, 4), axis=1),
        )
        assert (argmax_RES == argmax_res).all()

    def test_generate_until(self) -> None:
        """Greedy generations are byte-identical to the recorded references."""
        res = self.LM.generate_until(self.generate_until)
        assert res == self.generate_until_RES

    def test_logliklihood_rolling(self) -> None:
        """Rolling loglikelihoods match references within a loose tolerance."""
        res = self.LM.loglikelihood_rolling(self.ROLLING)
        assert np.allclose(res, self.ROLLING_RES, atol=1e-1)

    def test_toc_encode(self) -> None:
        """Tokenizer round-trip: text -> ids."""
        res = self.LM.tok_encode("foo bar")
        assert res == [12110, 2534]

    def test_toc_decode(self) -> None:
        """Tokenizer round-trip: ids -> text."""
        res = self.LM.tok_decode([12110, 2534])
        assert res == "foo bar"

    def test_batch_encode(self) -> None:
        """Batched encoding returns one id row per input string."""
        res = self.LM.tok_batch_encode(["foo bar", "bar foo"])[0].tolist()
        assert res == [[12110, 2534], [2009, 17374]]

    def test_model_generate(self) -> None:
        """Low-level _model_generate produces the recorded continuation."""
        context = self.LM.tok_batch_encode(["foo bar"])[0]
        res = self.LM._model_generate(context, max_length=10, stop=["\n\n"])
        res = self.LM.tok_decode(res[0])
        assert res == "foo bar\n<bazhang>!info bar"