Commit 3e28eed1 authored by Baber's avatar Baber
Browse files

add `max_thinking_tokens` for anthropic

parent c0fc7172
from __future__ import annotations
import logging import logging
import os import os
from functools import cached_property from functools import cached_property
from typing import Any, Dict, List, Tuple, Union from typing import Any
from tqdm import tqdm from tqdm import tqdm
...@@ -20,7 +22,7 @@ def anthropic_completion( ...@@ -20,7 +22,7 @@ def anthropic_completion(
prompt: str, prompt: str,
max_tokens_to_sample: int, max_tokens_to_sample: int,
temperature: float, temperature: float,
stop: List[str], stop: list[str],
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Wrapper function around the Anthropic completion API client with exponential back-off """Wrapper function around the Anthropic completion API client with exponential back-off
...@@ -83,7 +85,7 @@ def anthropic_chat( ...@@ -83,7 +85,7 @@ def anthropic_chat(
prompt: str, prompt: str,
max_tokens: int, max_tokens: int,
temperature: float, temperature: float,
stop: List[str], stop: list[str],
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Wrapper function around the Anthropic completion API client with exponential back-off """Wrapper function around the Anthropic completion API client with exponential back-off
...@@ -205,16 +207,16 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -205,16 +207,16 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
# Isn't used because we override _loglikelihood_tokens # Isn't used because we override _loglikelihood_tokens
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def tok_encode(self, string: str) -> List[int]: def tok_encode(self, string: str) -> list[int]:
return self.tokenizer.encode(string).ids return self.tokenizer.encode(string).ids
def tok_decode(self, tokens: List[int]) -> str: def tok_decode(self, tokens: list[int]) -> str:
return self.tokenizer.decode(tokens) return self.tokenizer.decode(tokens)
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False): def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
try: try:
import anthropic import anthropic
except ModuleNotFoundError as exception: except ModuleNotFoundError as exception:
...@@ -226,7 +228,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install ...@@ -226,7 +228,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
if not requests: if not requests:
return [] return []
_requests: List[Tuple[str, dict]] = [req.args for req in requests] _requests: list[tuple[str, dict]] = [req.args for req in requests]
res = [] res = []
for request in tqdm(_requests, disable=disable_tqdm): for request in tqdm(_requests, disable=disable_tqdm):
...@@ -279,6 +281,7 @@ class AnthropicChat(LocalCompletionsAPI): ...@@ -279,6 +281,7 @@ class AnthropicChat(LocalCompletionsAPI):
self, self,
base_url="https://api.anthropic.com/v1/messages", base_url="https://api.anthropic.com/v1/messages",
tokenizer_backend=None, tokenizer_backend=None,
max_thinking_tokens: int | None = None,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
...@@ -288,6 +291,11 @@ class AnthropicChat(LocalCompletionsAPI): ...@@ -288,6 +291,11 @@ class AnthropicChat(LocalCompletionsAPI):
"Chat completions does not support batching. Defaulting to batch size 1." "Chat completions does not support batching. Defaulting to batch size 1."
) )
self._batch_size = 1 self._batch_size = 1
if max_thinking_tokens == 0:
max_thinking_tokens = None
if max_thinking_tokens is not None:
assert max_thinking_tokens >= 1024, "max_thinking_tokens must be >= 1024"
self.max_thinking_tokens = max_thinking_tokens
self.anthropic_version = "2023-06-01" self.anthropic_version = "2023-06-01"
eval_logger.warning( eval_logger.warning(
f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning" f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
...@@ -312,12 +320,13 @@ class AnthropicChat(LocalCompletionsAPI): ...@@ -312,12 +320,13 @@ class AnthropicChat(LocalCompletionsAPI):
def _create_payload( def _create_payload(
self, self,
messages: List[Dict], messages: list[dict],
generate=True, generate=True,
gen_kwargs: dict = None, gen_kwargs: dict | None = None,
eos="\n\nHuman:", eos="\n\nHuman:",
**kwargs, **kwargs,
) -> dict: ) -> dict:
gen_kwargs = gen_kwargs or {}
system = ( system = (
messages[0].get("content") if messages[0].get("role") == "system" else None messages[0].get("content") if messages[0].get("role") == "system" else None
) )
...@@ -354,17 +363,20 @@ class AnthropicChat(LocalCompletionsAPI): ...@@ -354,17 +363,20 @@ class AnthropicChat(LocalCompletionsAPI):
} }
if system: if system:
out["system"] = system out["system"] = system
        if self.max_thinking_tokens:
            # The Anthropic Messages API expects "thinking" to be a JSON object,
            # not a tuple — the original trailing comma wrapped the dict in a
            # 1-tuple, which the API would reject.
            out["thinking"] = {
                "type": "enabled",
                "budget_tokens": self.max_thinking_tokens,
            }
return out return out
def parse_generations( def parse_generations(self, outputs: dict | list[dict], **kwargs) -> list[str]:
self, outputs: Union[Dict, List[Dict]], **kwargs
) -> List[str]:
res = [] res = []
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
for out in outputs: for out in outputs:
for choices in out["content"]: for choices in out["content"]:
res.append(choices["text"]) if _out := choices.get("text"):
res.append(_out)
return res return res
def tok_encode( def tok_encode(
...@@ -373,7 +385,7 @@ class AnthropicChat(LocalCompletionsAPI): ...@@ -373,7 +385,7 @@ class AnthropicChat(LocalCompletionsAPI):
left_truncate_len=None, left_truncate_len=None,
add_special_tokens=None, add_special_tokens=None,
**kwargs, **kwargs,
) -> List[str]: ) -> list[str]:
return [string] return [string]
def loglikelihood(self, requests, **kwargs): def loglikelihood(self, requests, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment