from __future__ import annotations

import logging
import os
from functools import cached_property
from typing import Any

from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions


eval_logger = logging.getLogger(__name__)

def anthropic_completion(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: list[str],
    **kwargs: Any,
) -> str:
    """Query Anthropic's legacy text-completions endpoint for a single prompt.

    Retries forever (with exponential back-off) whenever the API raises a
    ``RateLimitError``.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        prompt: str
            Prompt to feed to the model
        max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences
        kwargs: Any
            Additional model_args to pass to the API client
    """
    try:
        import anthropic
    except ModuleNotFoundError as exception:
        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        )

    def _warn_on_rate_limit(e: Exception, sleep_time: float) -> None:
        # Surface every back-off in the logs so long waits are visible.
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[anthropic.RateLimitError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_warn_on_rate_limit,
    )
    def completion():
        # The legacy API requires the Human/Assistant turn markers around the prompt.
        wrapped_prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
        # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
        #       (e.g. gsm8k's ":") may truncate a lot of the input.
        stop_sequences = [anthropic.HUMAN_PROMPT] + stop
        response = client.completions.create(
            prompt=wrapped_prompt,
            model=model,
            stop_sequences=stop_sequences,
            max_tokens_to_sample=max_tokens_to_sample,
            temperature=temperature,
            **kwargs,
        )
        return response.completion

    return completion()
def anthropic_chat(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    stop: list[str],
    **kwargs: Any,
) -> str:
    """Query Anthropic's Messages (chat) endpoint for a single user prompt.

    Retries forever (with exponential back-off) on rate-limit, connection,
    and API-status errors.

    NOTE: ``stop`` is accepted for signature parity with
    :func:`anthropic_completion` but is not forwarded to the API here.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
        prompt: str
            Prompt to feed to the model
        max_tokens: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences
        kwargs: Any
            Additional model_args to pass to the API client
    """
    try:
        import anthropic
    except ModuleNotFoundError as exception:
        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        )

    def _warn_on_retry(e: Exception, sleep_time: float) -> None:
        # Surface every back-off in the logs so long waits are visible.
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[
            anthropic.RateLimitError,
            anthropic.APIConnectionError,
            anthropic.APIStatusError,
        ],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_warn_on_retry,
    )
    def messages():
        conversation = [{"role": "user", "content": f"{prompt}"}]
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=conversation,
            **kwargs,
        )
        # Return the text of the first content block.
        return response.content[0].text

    return messages()


Baber Abbasi's avatar
Baber Abbasi committed
146
@register_model("anthropic-completions")
class AnthropicLM(LM):
    """lm-eval wrapper around Anthropic's legacy text-completions API.

    Only generation is supported; all loglikelihood paths raise, since the
    API does not expose logits.
    """

    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size: int = 1,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 0,  # defaults to 1
        **kwargs,  # top_p, top_k, etc.
    ) -> None:
        """Anthropic API wrapper.

        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        :param temperature: float
            Sampling temperature
        :param kwargs: Any
            Additional model_args to pass to the API client
        """
        super().__init__()

        try:
            import anthropic
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            )

        self.model = model
        # defaults to os.environ.get("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        self.tokenizer = self.client.get_tokenizer()
        self.kwargs = kwargs

    @property
    def eot_token_id(self):
        # Not sure but anthropic.HUMAN_PROMPT ?
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self) -> int:
        # Fixed context budget used as the truncation limit by the harness.
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return self.max_tokens_to_sample

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    def tok_encode(self, string: str) -> list[int]:
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: list[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
        """Generate a completion for each request via the Anthropic API.

        Each request's args are ``(prompt, generation_kwargs)``. Stops early
        (returning partial results) on connection or API-status errors.
        """
        try:
            import anthropic
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            )

        if not requests:
            return []

        _requests: list[tuple[str, dict]] = [req.args for req in requests]

        res = []
        for request in tqdm(_requests, disable=disable_tqdm):
            try:
                inp = request[0]
                request_args = request[1]
                # generation_kwargs
                # BUGFIX: "until" may be missing or explicitly None; default to []
                # so anthropic_completion's `[HUMAN_PROMPT] + stop` concatenation
                # does not raise TypeError.
                until = request_args.get("until") or []
                # NOTE(review): defaulting max_gen_toks to self.max_length (2048)
                # rather than self.max_gen_toks looks inconsistent — kept as-is
                # to preserve behavior; confirm intent upstream.
                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
                temperature = request_args.get("temperature", self.temperature)
                response = anthropic_completion(
                    client=self.client,
                    model=self.model,
                    prompt=inp,
                    max_tokens_to_sample=max_gen_toks,
                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                    stop=until,
                    **self.kwargs,
                )
                res.append(response)

                self.cache_hook.add_partial("generate_until", request, response)
            except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"Server unreachable: {e.__cause__}")
                break
            except anthropic.APIStatusError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"API error {e.status_code}: {e.message}")
                break

        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")
@register_model("anthropic-chat", "anthropic-chat-completions")
class AnthropicChat(LocalCompletionsAPI):
    """lm-eval wrapper for Anthropic's Messages API, built on the generic
    HTTP completions client.

    Loglikelihood is unsupported; only chat generation is implemented.
    """

    def __init__(
        self,
        base_url="https://api.anthropic.com/v1/messages",
        tokenizer_backend=None,
        max_thinking_tokens: int | None = None,
        **kwargs,
    ):
        """
        :param base_url: str
            Anthropic Messages API endpoint.
        :param tokenizer_backend: None
            Unused; Anthropic does not expose a local tokenizer here.
        :param max_thinking_tokens: int | None
            Extended-thinking budget. 0 disables thinking; otherwise must be
            >= 1024 per the API's minimum budget.
        """
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )
        eval_logger.warning(
            "Chat completions does not support batching. Defaulting to batch size 1."
        )
        self._batch_size = 1
        # Treat 0 as "thinking disabled".
        if max_thinking_tokens == 0:
            max_thinking_tokens = None
        if max_thinking_tokens is not None:
            assert max_thinking_tokens >= 1024, "max_thinking_tokens must be >= 1024"
        self.max_thinking_tokens = max_thinking_tokens
        self.anthropic_version = "2023-06-01"
        eval_logger.warning(
            f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
        )

    @cached_property
    def api_key(self):
        """Override this property to return the API key for the API request."""
        key = os.environ.get("ANTHROPIC_API_KEY", None)
        if key is None:
            raise ValueError(
                "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
            )
        return key

    @cached_property
    def header(self):
        # Anthropic authenticates via x-api-key plus a pinned API version header.
        return {
            "x-api-key": f"{self.api_key}",
            "anthropic-version": self.anthropic_version,
        }

    def _create_payload(
        self,
        messages: list[dict],
        generate=True,
        gen_kwargs: dict | None = None,
        eos="\n\nHuman:",
        **kwargs,
    ) -> dict:
        """Build the JSON payload for a Messages API request.

        A leading system message is lifted into the top-level ``system``
        field; remaining messages are rewrapped into Anthropic's typed
        content-block format.
        """
        gen_kwargs = gen_kwargs or {}
        system = (
            messages[0].get("content") if messages[0].get("role") == "system" else None
        )
        if system:
            messages = messages[1:]

        # Rewrap each message's content as a typed content block,
        # e.g. {"type": "text", "text": ...}.
        cleaned_messages = []
        for msg in messages:
            cleaned_msg = {
                "role": msg["role"],
                "content": [
                    {"type": msg["type"], msg["type"]: msg["content"]},
                ],
            }
            cleaned_messages.append(cleaned_msg)

        gen_kwargs.pop("do_sample", False)  # not an Anthropic parameter
        max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
        temperature = gen_kwargs.pop("temperature", 0)
        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
        if not isinstance(stop, list):
            stop = [stop]

        # Filter out empty or whitespace-only stop sequences for Anthropic API
        stop = [s for s in stop if s and s.strip()]

        out = {
            "messages": cleaned_messages,
            "model": self.model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stop_sequences": stop,
            **gen_kwargs,
        }
        if system:
            out["system"] = system
        if self.max_thinking_tokens:
            # BUGFIX: previously assigned a 1-tuple containing this dict
            # (stray trailing comma in a parenthesized assignment); the API
            # expects a plain "thinking" object.
            out["thinking"] = {
                "type": "enabled",
                "budget_tokens": self.max_thinking_tokens,
            }
        return out

    def parse_generations(self, outputs: dict | list[dict], **kwargs) -> list[str]:
        """Extract the text of each response's content blocks.

        Blocks without a non-empty "text" field (e.g. thinking blocks) are
        skipped.
        """
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            for choices in out["content"]:
                if _out := choices.get("text"):
                    res.append(_out)
        return res

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=None,
        **kwargs,
    ) -> list[str]:
        # No local tokenizer: treat the whole string as a single "token" so
        # the harness's length accounting still functions.
        return [string]

    def loglikelihood(self, requests, **kwargs):
        raise NotImplementedError(
            "Anthropic Chat Completions API does not support the return of loglikelihood"
        )