from typing import Any, List, Tuple

from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import retry_on_specific_exceptions


eval_logger = utils.eval_logger

def anthropic_completion(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: List[str],
    **kwargs: Any,
) -> str:
    """Wrapper function around the Anthropic completion API client with exponential back-off
    in case of RateLimitError.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        prompt: str
            Prompt to feed to the model
        max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences (``None`` is tolerated and treated as empty)
        kwargs: Any
            Additional model_args to pass to the API client
    """

    try:
        import anthropic
    except ModuleNotFoundError as exc:
        # Chain the original ImportError so the real cause shows in the traceback.
        raise Exception(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
        ) from exc

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[anthropic.RateLimitError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        response = client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
            model=model,
            # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
            #       (e.g. gsm8k's ":") may truncate a lot of the input.
            # `stop or []` guards against callers passing stop=None.
            stop_sequences=[anthropic.HUMAN_PROMPT] + (stop or []),
            max_tokens_to_sample=max_tokens_to_sample,
            temperature=temperature,
            **kwargs,
        )
        return response.completion

    return completion()


@register_model("anthropic")
class AnthropicLM(LM):
    """lm-eval model wrapper for the Anthropic completions API.

    Only generation is supported; every loglikelihood/logits entry point
    raises ``NotImplementedError`` because the API exposes no logits.
    """

    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size: int = 1,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 0,  # defaults to 1
        **kwargs,  # top_p, top_k, etc.
    ) -> None:
        """Anthropic API wrapper.

        :param batch_size: int
            Accepted for interface compatibility; not stored or used here.
        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        :param temperature: float
            Sampling temperature
        :param kwargs: Any
            Additional model_args to pass to the API client
        """
        super().__init__()

        try:
            import anthropic
        except ModuleNotFoundError as exc:
            # Chain the original ImportError so the real cause shows in the traceback.
            raise Exception(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
            ) from exc

        self.model = model
        # defaults to os.environ.get("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        self.tokenizer = self.client.get_tokenizer()
        self.kwargs = kwargs

    @property
    def eot_token_id(self):
        # Not sure but anthropic.HUMAN_PROMPT ?
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self) -> int:
        # Hard-coded context budget used as a fallback by generate_until.
        return 2048

    @property
    def max_gen_toks(self) -> int:
        # Mirrors the constructor's max_tokens_to_sample setting.
        return self.max_tokens_to_sample

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    def tok_encode(self, string: str) -> List[int]:
        """Encode ``string`` to token ids with the Anthropic-provided tokenizer."""
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text with the Anthropic-provided tokenizer."""
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def generate_until(self, requests) -> List[str]:
        """Generate a completion for each request, sequentially.

        Stops early (returning partial results) on connection or API-status
        errors from the Anthropic client.
        """
        try:
            import anthropic
        except ModuleNotFoundError as exc:
            raise Exception(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
            ) from exc

        if not requests:
            return []

        _requests: List[Tuple[str, dict]] = [req.args for req in requests]

        res = []
        for request in tqdm(_requests):
            try:
                inp = request[0]
                request_args = request[1]
                # generation_kwargs
                until = request_args.get("until")
                if isinstance(until, str):
                    # Tolerate a single stop sequence passed as a bare string.
                    until = [until]
                # NOTE(review): falls back to self.max_length (context size), not
                # self.max_gen_toks — looks suspicious, confirm this is intended.
                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
                temperature = request_args.get("temperature", self.temperature)
                response = anthropic_completion(
                    client=self.client,
                    model=self.model,
                    prompt=inp,
                    max_tokens_to_sample=max_gen_toks,
                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                    stop=until or [],  # type: ignore
                    **self.kwargs,
                )
                res.append(response)

                self.cache_hook.add_partial("generate_until", request, response)
            except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"Server unreachable: {e.__cause__}")
                break
            except anthropic.APIStatusError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"API error {e.status_code}: {e.message}")
                break

        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")