""" TextSynth API
Implementation provided by Fabrice Bellard:
    https://github.com/EleutherAI/lm-evaluation-harness/issues/295

In order to use the API, you must have a valid TextSynth account and
enough credits.

Example usage:

    python main.py --model textsynth --model_args engine=gptj_6B --no_cache --tasks piqa

Homepage: https://textsynth.com/index.html
"""
import logging
import os
import time
import traceback

import requests as _requests
from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


logger = logging.getLogger(__name__)


def textsynth_completion(**kwargs):
    """Query the TextSynth API for a completion.
    Retry with exponential back-off until the server responds.
    """
    backoff_time = 3
    while True:
        try:
            return _requests.post(**kwargs)
        except _requests.exceptions.RequestException:
            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("textsynth")
class TextSynthLM(LM):
    def __init__(self, engine, truncate: bool = False) -> None:
        """
        :param engine: str
            TextSynth API engine (e.g. `gptj_6B`)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        self.engine = engine
        self.truncate = truncate
        self.api_url = "https://api.textsynth.com"
        # Read from environment variable TEXTSYNTH_API_SECRET_KEY
        self.api_key = os.environ["TEXTSYNTH_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    @property
    def max_length(self) -> int:
        # NOTE: Turn on truncation to avoid errors on long inputs.
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    def tok_encode(self, string: str):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    def tok_decode(self, tokens):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests):
        res = []
        for context, continuation in tqdm(requests):
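            # The /logprob endpoint scores `continuation` given `context`,
            # returning its log-probability and whether it is the greedy
            # completion.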
            response = textsynth_completion(
                url=self.api_url + "/v1/engines/" + self.engine + "/logprob",
                headers={"Authorization": "Bearer " + self.api_key},
                json={"context": context, "continuation": continuation},
            )
            resp = response.json()
            if "logprob" in resp:
                logprob = resp["logprob"]
                is_greedy = resp["is_greedy"]
                res.append((logprob, is_greedy))

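                # Record the partial result with the cache hook so cached
                # runs can skip the API call.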
                self.cache_hook.add_partial(
                    "loglikelihood", (context, continuation), (logprob, is_greedy)
                )
            else:
                logger.error(
                    f"The following response does not contain `logprob`. Got:\n{resp}"
                )
                assert False
        return res

    def loglikelihood_rolling(self, requests):
        # TODO: The TextSynth API does not support tokenized inputs so we cannot
        # manually partition long contexts into smaller rolling windows as
        # done for other models derived from `BaseLM`. Override this method
        # with a windowing scheme that works for direct string inputs.
        raise NotImplementedError(
            "`loglikelihood_rolling` is currently not supported due to lack of "
            "input tokenization support from TextSynth."
        )

    def generate_until(self, requests):
        if not requests:
            return []

        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = textsynth_completion(
                url=self.api_url + "/v1/engines/" + self.engine + "/completions",
                headers={"Authorization": "Bearer " + self.api_key},
                json={
                    "prompt": inp,
                    "max_tokens": self.max_gen_toks,
                    "top_k": 1,
                    "stop": until,
                },
            )
            resp = response.json()
            if "text" in resp:
                s = resp["text"]
                res.append(s)

                self.cache_hook.add_partial("generate_until", (inp, request_args), s)
            else:
                logger.error(
                    "The following response does not contain generated `text`. "
                    f"Got:\n{resp}"
                )
                assert False
        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()