# anthropic_llms.py - Anthropic model backends for lm-eval.
import logging
import os
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple, Union

from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
# Module-level logger used by the wrapper functions and model classes below.
eval_logger = logging.getLogger(__name__)


def anthropic_completion(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: Optional[List[str]],
    **kwargs: Any,
) -> str:
    """Wrapper function around the Anthropic completion API client with back-off
    and retry in case of RateLimitError.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        prompt: str
            Prompt to feed to the model; it is wrapped between
            `anthropic.HUMAN_PROMPT` and `anthropic.AI_PROMPT` before sending
        max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: Optional[List[str]]
            Extra stop sequences. `None` means "no extra stop sequences" and a
            bare string is treated as a single stop sequence.
            `anthropic.HUMAN_PROMPT` is always prepended.
        kwargs: Any
            Additional model_args to pass to the API client

    returns:
        The `completion` text of the API response.

    raises:
        ModuleNotFoundError: if the `anthropic` package is not installed.
    """

    try:
        import anthropic
    except ModuleNotFoundError as exception:
        # Chain the original error so the failing import stays visible.
        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        ) from exception

    # Callers may hand over `None` (no "until" configured) or a single string;
    # normalise both so the list concatenation below cannot fail.
    if stop is None:
        extra_stops: List[str] = []
    elif isinstance(stop, str):
        extra_stops = [stop]
    else:
        extra_stops = list(stop)
    stop_sequences = [anthropic.HUMAN_PROMPT] + extra_stops

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        # `__cause__` is only set for chained errors; fall back to the
        # exception itself so the log line never reads "None".
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__ or e}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[anthropic.RateLimitError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        response = client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
            model=model,
            # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
            #       (e.g. gsm8k's ":") may truncate a lot of the input.
            stop_sequences=stop_sequences,
            max_tokens_to_sample=max_tokens_to_sample,
            temperature=temperature,
            **kwargs,
        )
        return response.completion

    return completion()


def anthropic_chat(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    stop: Optional[List[str]],
    **kwargs: Any,
) -> str:
    """Wrapper function around the Anthropic Messages API client with back-off
    and retry on rate-limit, connection and status errors.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
        prompt: str
            Prompt to feed to the model; sent as a single "user" message
        max_tokens: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: Optional[List[str]]
            List of stop sequences. `None`, empty and whitespace-only entries
            are dropped; a bare string is treated as one stop sequence. The
            remaining entries are sent as `stop_sequences` unless the caller
            already supplied `stop_sequences` through `kwargs`.
        kwargs: Any
            Additional model_args to pass to the API client

    returns:
        The concatenated text of all content blocks in the response that carry
        a `text` attribute (an empty string if there are none).

    raises:
        ModuleNotFoundError: if the `anthropic` package is not installed.
    """

    try:
        import anthropic
    except ModuleNotFoundError as exception:
        # Chain the original error so the failing import stays visible.
        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        ) from exception

    # Normalise the stop sequences and drop blank ones.
    if stop is None:
        stop_candidates: List[str] = []
    elif isinstance(stop, str):
        stop_candidates = [stop]
    else:
        stop_candidates = list(stop)
    stop_sequences = [s for s in stop_candidates if s and s.strip()]

    # Copy so the caller's kwargs are untouched; an explicit `stop_sequences`
    # in kwargs wins over the `stop` argument to avoid a duplicate keyword.
    request_kwargs = dict(kwargs)
    if stop_sequences:
        request_kwargs.setdefault("stop_sequences", stop_sequences)

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        # Several exception types are retried here, so report the actual one.
        eval_logger.warning(
            f"{type(e).__name__} occurred: {e.__cause__ or e}\n Retrying in {sleep_time} seconds"
        )

    # NOTE(review): APIStatusError is retried without limit; if it also covers
    # non-transient failures (e.g. bad credentials) this loops forever - confirm.
    @retry_on_specific_exceptions(
        on_exceptions=[
            anthropic.RateLimitError,
            anthropic.APIConnectionError,
            anthropic.APIStatusError,
        ],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def messages():
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=[{"role": "user", "content": f"{prompt}"}],
            **request_kwargs,
        )
        # Join every text-bearing block instead of assuming exactly one block
        # whose first element is text.
        return "".join(getattr(block, "text", None) or "" for block in response.content)

    return messages()


@register_model("anthropic-completions")
class AnthropicLM(LM):
    """lm-eval model backed by the Anthropic text-completions client.

    Only `generate_until` is implemented. Every log-likelihood entry point
    raises `NotImplementedError` ("No support for logits.").
    """

    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size: int = 1,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 0,  # 0 here; original note: "defaults to 1"
        **kwargs,  # top_p, top_k, etc.
    ) -> None:
        """Anthropic API wrapper.

        :param batch_size: int
            Accepted for interface compatibility; not used by this class.
        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            Maximum number of tokens to sample from the model; used as the
            default generation budget when a request does not set
            `max_gen_toks`.
        :param temperature: float
            Sampling temperature
        :param kwargs: Any
            Additional model_args to pass to the API client
        :raises ModuleNotFoundError: if the `anthropic` package is missing.
        """
        super().__init__()

        try:
            import anthropic
        except ModuleNotFoundError as exception:
            # Chain the original error so the failing import stays visible.
            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            ) from exception

        self.model = model
        # defaults to os.environ.get("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        self.tokenizer = self.client.get_tokenizer()
        self.kwargs = kwargs

    @property
    def eot_token_id(self):
        # Not sure but anthropic.HUMAN_PROMPT ?
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self) -> int:
        """Hard-coded maximum context length reported by this backend."""
        return 2048

    @property
    def max_gen_toks(self) -> int:
        """Default generation budget, i.e. the constructor's `max_tokens_to_sample`."""
        return self.max_tokens_to_sample

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    def tok_encode(self, string: str) -> List[int]:
        """Encode `string` with the client's tokenizer and return the token ids."""
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back into text with the client's tokenizer."""
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        """Generate one completion per request through `anthropic_completion`.

        Each request's `args` is read as `(prompt, generation_kwargs)`. The
        generation kwargs used are `until`, `max_gen_toks` (default:
        `self.max_gen_toks`) and `temperature` (default: `self.temperature`).

        NOTE: on `APIConnectionError` / `APIStatusError` the loop logs a
        critical message and stops, so the returned list can be shorter than
        `requests`.

        :raises ModuleNotFoundError: if the `anthropic` package is missing.
        """
        try:
            import anthropic
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            ) from exception

        if not requests:
            return []

        _requests: List[Tuple[str, dict]] = [req.args for req in requests]

        res = []
        for request in tqdm(_requests, disable=disable_tqdm):
            try:
                inp = request[0]
                request_args = request[1]

                # generation_kwargs
                until = request_args.get("until")
                # The completion wrapper concatenates lists, so normalise a
                # missing or single-string "until" first.
                if until is None:
                    until = []
                elif isinstance(until, str):
                    until = [until]
                # Fall back to the configured generation budget rather than
                # the context length, so `max_tokens_to_sample` is honoured.
                max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
                temperature = request_args.get("temperature", self.temperature)

                response = anthropic_completion(
                    client=self.client,
                    model=self.model,
                    prompt=inp,
                    max_tokens_to_sample=max_gen_toks,
                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                    stop=until,
                    **self.kwargs,
                )
                res.append(response)

                self.cache_hook.add_partial("generate_until", request, response)
            except anthropic.APIConnectionError as e:
                eval_logger.critical(f"Server unreachable: {e.__cause__}")
                break
            except anthropic.APIStatusError as e:
                eval_logger.critical(f"API error {e.status_code}: {e.message}")
                break

        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")


@register_model("anthropic-chat", "anthropic-chat-completions")
class AnthropicChat(LocalCompletionsAPI):
    """lm-eval model that talks to the Anthropic Messages HTTP endpoint.

    Builds Messages-API payloads from chat histories and extracts the
    generated text from the JSON responses. Log-likelihood requests are not
    supported.
    """

    def __init__(
        self,
        base_url="https://api.anthropic.com/v1/messages",
        tokenizer_backend=None,
        **kwargs,
    ):
        """Configure the endpoint, force batch size 1 and pin the API version.

        :param base_url: URL of the Messages endpoint.
        :param tokenizer_backend: forwarded to the parent class; `None` by default.
        :param kwargs: forwarded unchanged to the parent class.
        """
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )
        eval_logger.warning(
            "Chat completions does not support batching. Defaulting to batch size 1."
        )
        self._batch_size = 1
        self.anthropic_version = "2023-06-01"
        eval_logger.warning(
            f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request.

        :raises ValueError: if `ANTHROPIC_API_KEY` is not set.
        """
        key = os.environ.get("ANTHROPIC_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
            )
        return key

    @cached_property
    def header(self):
        """HTTP headers carrying the API key and the pinned API version."""
        return {
            "x-api-key": f"{self.api_key}",
            "anthropic-version": self.anthropic_version,
        }

    def _create_payload(
        self,
        messages: List[Dict],
        generate=True,
        gen_kwargs: Optional[dict] = None,
        eos="\n\nHuman:",
        **kwargs,
    ) -> dict:
        """Build the JSON body for one Messages API request.

        :param messages: chat history as dicts with "role" and "content" and
            an optional "type" (defaults to "text"). A leading "system"
            message is moved to the top-level `system` field.
        :param generate: unused here; kept for interface compatibility.
        :param gen_kwargs: generation options. `do_sample` is discarded;
            `max_gen_toks`, `temperature` and `until` are translated; anything
            else is passed through. The caller's dict is not modified.
        :param eos: end-of-sequence string handed to `handle_stop_sequences`.
        :return: the request payload.
        """
        # Copy: the options are popped below and must not leak to the caller;
        # this also makes the `None` default usable.
        gen_kwargs = dict(gen_kwargs) if gen_kwargs else {}

        # Strip a leading system message by role, not by content truthiness,
        # so an empty system prompt is not sent as a regular message.
        system = None
        if messages and messages[0].get("role") == "system":
            system = messages[0].get("content")
            messages = messages[1:]

        cleaned_messages = []
        for msg in messages:
            # Histories without an explicit "type" are treated as plain text.
            block_type = msg.get("type", "text")
            cleaned_messages.append(
                {
                    "role": msg["role"],
                    "content": [
                        {"type": block_type, block_type: msg["content"]},
                    ],
                }
            )

        gen_kwargs.pop("do_sample", False)
        max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
        if not isinstance(stop, list):
            stop = [stop]

        # Filter out empty or whitespace-only stop sequences for Anthropic API
        stop = [s for s in stop if s and s.strip()]

        out = {
            "messages": cleaned_messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stop_sequences": stop,
            **gen_kwargs,
        }
        if system:
            out["system"] = system
        return out

    def parse_generations(
        self, outputs: Union[Dict, List[Dict]], **kwargs
    ) -> List[str]:
        """Extract the generated text from one or more API responses.

        Returns exactly one string per response: the concatenation of the
        response's text content blocks. Blocks of any other type are skipped.
        """
        if not isinstance(outputs, list):
            outputs = [outputs]
        res = []
        for out in outputs:
            # One entry per response keeps results aligned with requests even
            # when a response holds several (or non-text) content blocks.
            res.append(
                "".join(
                    block.get("text") or ""
                    for block in out["content"]
                    if block.get("type", "text") == "text"
                )
            )
        return res

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> List[str]:
        """Return the input wrapped in a list; no tokenization is performed."""
        return [string]

    def loglikelihood(self, requests, **kwargs):
        raise NotImplementedError(
            "Anthropic Chat Completions API does not support the return of loglikelihood"
        )