anthropic_llms.py 4.93 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
import os
haileyschoelkopf's avatar
haileyschoelkopf committed
2
3
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
Jason Phang's avatar
Jason Phang committed
4
5
from tqdm import tqdm
import time
baberabb's avatar
baberabb committed
6
7
import anthropic
from lm_eval.logger import eval_logger
baberabb's avatar
baberabb committed
8
from typing import List, Literal, Any
Jason Phang's avatar
Jason Phang committed
9
10


lintangsutawika's avatar
lintangsutawika committed
11
def anthropic_completion(
    client: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: List[str],
    **kwargs: Any,
) -> str:
    """Query the Anthropic completions API and return the completion text.

    Retries with exponential back-off on rate limits. Connection and
    API-status errors are logged and re-raised: previously this function
    ``break``-ed out of the retry loop and implicitly returned ``None``,
    which the caller would silently append to (and cache as) a result.

    :param client: an initialized ``anthropic.Anthropic`` client
    :param model: model name, e.g. ``"claude-2.0"``
    :param prompt: raw prompt text; HUMAN/AI turn markers are added here
    :param max_tokens_to_sample: maximum number of tokens to generate
    :param temperature: sampling temperature
    :param stop: extra stop sequences (``HUMAN_PROMPT`` is always included)
    :param kwargs: forwarded verbatim to ``client.completions.create``
        (e.g. ``top_p``, ``top_k``)
    :return: the generated completion string
    :raises anthropic.APIConnectionError: if the server is unreachable
    :raises anthropic.APIStatusError: on a non-retryable API error
    """
    backoff_time = 3
    while True:
        try:
            response = client.completions.create(
                prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                model=model,
                # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
                #       (e.g. gsm8k's ":") may truncate a lot of the input.
                stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
                max_tokens_to_sample=max_tokens_to_sample,
                temperature=temperature,
                **kwargs,
            )
            return response.completion
        except anthropic.RateLimitError as e:
            eval_logger.warning(
                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
            )
            time.sleep(backoff_time)
            backoff_time *= 1.5
        except anthropic.APIConnectionError as e:
            eval_logger.critical(f"Server unreachable: {e.__cause__}")
            # Fail loudly instead of returning an implicit None result.
            raise e
        except anthropic.APIStatusError as e:
            eval_logger.critical(f"API error {e.status_code}: {e.message}")
            raise e
Jason Phang's avatar
Jason Phang committed
50
51


haileyschoelkopf's avatar
haileyschoelkopf committed
52
@register_model("anthropic")
class AnthropicLM(LM):
    """``LM`` wrapper around the Anthropic completions API.

    Only generation (``greedy_until``) is supported; every logit-based
    entry point raises ``NotImplementedError``.
    """

    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size=None,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 1.0,  # defaults to 1
        **kwargs: Any,  # top_p, top_k, etc.
    ):  # TODO: remove batch_size
        """Anthropic API wrapper.

        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        :param temperature: float
            Sampling temperature
        :param kwargs: Any
            Additional model_args to pass to the API client
        """
        super().__init__()

        self.model = model
        # The client picks up os.environ.get("ANTHROPIC_API_KEY") by default.
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        self.tokenizer = self.client.get_tokenizer()
        self.kwargs = kwargs

    @property
    def eot_token_id(self):
        # Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self):
        return 2048

    @property
    def max_gen_toks(self):
        return self.max_tokens_to_sample

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    def tok_encode(self, string: str) -> List[int]:
        encoding = self.tokenizer.encode(string)
        return encoding.ids

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        raise NotImplementedError("No support for logits.")

    def greedy_until(self, requests):
        # Guard clause: nothing to generate.
        if not requests:
            return []

        requests = [req.args for req in requests]

        results = []
        for request in tqdm(requests):
            context, gen_kwargs = request
            completion = anthropic_completion(
                client=self.client,
                model=self.model,
                prompt=context,
                max_tokens_to_sample=self.max_tokens_to_sample,
                temperature=self.temperature,  # TODO: implement non-greedy sampling for Anthropic
                stop=gen_kwargs["until"],
                **self.kwargs,
            )
            results.append(completion)

            self.cache_hook.add_partial("greedy_until", request, completion)

        return results

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")