Commit 3e28eed1 authored by Baber's avatar Baber
Browse files

add `max_thinking_tokens` for anthropic

parent c0fc7172
from __future__ import annotations
import logging
import os
from functools import cached_property
from typing import Any, Dict, List, Tuple, Union
from typing import Any
from tqdm import tqdm
......@@ -20,7 +22,7 @@ def anthropic_completion(
prompt: str,
max_tokens_to_sample: int,
temperature: float,
stop: List[str],
stop: list[str],
**kwargs: Any,
) -> str:
"""Wrapper function around the Anthropic completion API client with exponential back-off
......@@ -83,7 +85,7 @@ def anthropic_chat(
prompt: str,
max_tokens: int,
temperature: float,
stop: List[str],
stop: list[str],
**kwargs: Any,
) -> str:
"""Wrapper function around the Anthropic completion API client with exponential back-off
......@@ -205,16 +207,16 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError("No support for logits.")
def tok_encode(self, string: str) -> List[int]:
def tok_encode(self, string: str) -> list[int]:
    """Encode ``string`` into a list of token ids via the underlying tokenizer."""
    # The tokenizer's encode() returns an encoding object; its `ids`
    # attribute holds the integer token ids we need.
    encoding = self.tokenizer.encode(string)
    return encoding.ids
def tok_decode(self, tokens: List[int]) -> str:
def tok_decode(self, tokens: list[int]) -> str:
    """Decode a list of token ids back into a string via the underlying tokenizer."""
    decoded = self.tokenizer.decode(tokens)
    return decoded
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
    # The Anthropic API does not expose token logits/logprobs, so
    # loglikelihood scoring cannot be implemented for this backend.
    raise NotImplementedError("No support for logits.")
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
try:
import anthropic
except ModuleNotFoundError as exception:
......@@ -226,7 +228,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
if not requests:
return []
_requests: List[Tuple[str, dict]] = [req.args for req in requests]
_requests: list[tuple[str, dict]] = [req.args for req in requests]
res = []
for request in tqdm(_requests, disable=disable_tqdm):
......@@ -279,6 +281,7 @@ class AnthropicChat(LocalCompletionsAPI):
self,
base_url="https://api.anthropic.com/v1/messages",
tokenizer_backend=None,
max_thinking_tokens: int | None = None,
**kwargs,
):
super().__init__(
......@@ -288,6 +291,11 @@ class AnthropicChat(LocalCompletionsAPI):
"Chat completions does not support batching. Defaulting to batch size 1."
)
self._batch_size = 1
# A budget of 0 is treated as "thinking disabled".
if max_thinking_tokens == 0:
    max_thinking_tokens = None
if max_thinking_tokens is not None and max_thinking_tokens < 1024:
    # Anthropic's extended-thinking API requires budget_tokens >= 1024.
    # Raise ValueError instead of assert so the check survives `python -O`.
    raise ValueError("max_thinking_tokens must be >= 1024")
self.max_thinking_tokens = max_thinking_tokens
self.anthropic_version = "2023-06-01"
eval_logger.warning(
f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
......@@ -312,12 +320,13 @@ class AnthropicChat(LocalCompletionsAPI):
def _create_payload(
self,
messages: List[Dict],
messages: list[dict],
generate=True,
gen_kwargs: dict = None,
gen_kwargs: dict | None = None,
eos="\n\nHuman:",
**kwargs,
) -> dict:
gen_kwargs = gen_kwargs or {}
system = (
messages[0].get("content") if messages[0].get("role") == "system" else None
)
......@@ -354,17 +363,20 @@ class AnthropicChat(LocalCompletionsAPI):
}
if system:
out["system"] = system
if self.max_thinking_tokens:
out["thinking"] = (
{"type": "enabled", "budget_tokens": self.max_thinking_tokens},
)
return out
def parse_generations(
self, outputs: Union[Dict, List[Dict]], **kwargs
) -> List[str]:
def parse_generations(self, outputs: dict | list[dict], **kwargs) -> list[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
for choices in out["content"]:
res.append(choices["text"])
if _out := choices.get("text"):
res.append(_out)
return res
def tok_encode(
......@@ -373,7 +385,7 @@ class AnthropicChat(LocalCompletionsAPI):
left_truncate_len=None,
add_special_tokens=None,
**kwargs,
) -> List[str]:
) -> list[str]:
return [string]
def loglikelihood(self, requests, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment