""" TextSynth API
Implementation provided by Fabrice Bellard:
    https://github.com/EleutherAI/lm-evaluation-harness/issues/295

In order to use the API, you must have a valid TextSynth account and
enough credits.

Example usage:

    python main.py --model textsynth --model_args engine=gptj_6B --no_cache --tasks piqa

Homepage: https://textsynth.com/index.html
"""
import logging
import os

import requests as _requests
from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import retry_on_specific_exceptions


logger = logging.getLogger(__name__)


def textsynth_completion(**kwargs):
    """Query TextSynth API for completion.
    Retry with back-off until they respond.
    """

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        import traceback

        traceback.print_exc()

    @retry_on_specific_exceptions(
        on_exceptions=[_requests.exceptions.RequestException],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        return _requests.post(**kwargs)

    return completion()
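

# Illustrative call sketch (not executed; the example prompt is made up):
# `textsynth_completion` forwards its keyword arguments straight to
# `requests.post`, so a log-probability query looks like:
#
#     response = textsynth_completion(
#         url="https://api.textsynth.com/v1/engines/gptj_6B/logprob",
#         headers={"Authorization": "Bearer " + os.environ["TEXTSYNTH_API_SECRET_KEY"]},
#         json={"context": "The capital of France is", "continuation": " Paris"},
#     )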


@register_model("textsynth")
class TextSynthLM(LM):
    def __init__(self, engine, truncate: bool = False, **kwargs) -> None:
        """
        :param engine: str
            TextSynth API engine (e.g. `gptj_6B`)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        self.engine = engine
        self.truncate = truncate
        self.api_url = "https://api.textsynth.com"
        # Read from environment variable TEXTSYNTH_API_SECRET_KEY
        self.api_key = os.environ["TEXTSYNTH_API_SECRET_KEY"]
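
    # Usage sketch (illustrative; assumes TEXTSYNTH_API_SECRET_KEY is set):
    #
    #     lm = TextSynthLM(engine="gptj_6B", truncate=False)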

    @property
    def eot_token_id(self):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    @property
    def max_length(self) -> int:
        # NOTE: Turn on truncation to avoid errors on long inputs.
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    def tok_encode(self, string: str):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    def tok_decode(self, tokens):
        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        res = []
        for context, continuation in tqdm(requests, disable=disable_tqdm):
            response = textsynth_completion(
                url=self.api_url + "/v1/engines/" + self.engine + "/logprob",
                headers={"Authorization": "Bearer " + self.api_key},
                json={"context": context, "continuation": continuation},
            )
            resp = response.json()
            if "logprob" in resp:
                logprob = resp["logprob"]
                is_greedy = resp["is_greedy"]
                res.append((logprob, is_greedy))

                self.cache_hook.add_partial(
                    "loglikelihood", (context, continuation), (logprob, is_greedy)
                )
            else:
                logger.error(
                    f"The following response does not contain `logprobs`. Got:\n{resp}"
                )
                assert False
        return res
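
    # Shape sketch of a successful `/logprob` response as consumed above
    # (values illustrative):
    #
    #     {"logprob": -1.23, "is_greedy": True}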

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        # TODO: The TextSynth API does not support tokenized inputs so we cannot
        # manually partition long contexts into smaller rolling windows as
        # done for other models derived from `BaseLM`. Override this method
        # with a windowing scheme that works for direct string inputs.
        raise NotImplementedError(
            "`loglikelihood_rolling` is currently not supported due to lack of "
            "input tokenization support from TextSynth."
        )

    def generate_until(self, requests, disable_tqdm: bool = False):
        if not requests:
            return []

        res = []
        for request in tqdm(requests, disable=disable_tqdm):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = textsynth_completion(
                url=self.api_url + "/v1/engines/" + self.engine + "/completions",
                headers={"Authorization": "Bearer " + self.api_key},
                json={
                    "prompt": inp,
                    "max_tokens": self.max_gen_toks,
                    "top_k": 1,
                    "stop": until,
                },
            )
            resp = response.json()
            if "text" in resp:
                s = resp["text"]
                res.append(s)

                self.cache_hook.add_partial("generate_until", (inp, request_args), s)
            else:
                logger.error(
                    "The following response does not contain generated `text`. "
                    "Got:\n{resp}"
                )
                assert False
        return res
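
    # Request shape sketch: each element of `requests` is expected to be a
    # `(context, gen_kwargs)` pair, e.g. (values illustrative):
    #
    #     ("Question: 2 + 2 = ? Answer:", {"until": ["\n"]})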

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()