Unverified commit 42dc2448, authored by Baber Abbasi and committed by GitHub

Refactor API models (#2008)



* refactor pad_token handling to fn

* fix docs

* add pad_token_handling to vllm

* start on API superclass

* don't detokenize the returned logits

* streamline vllm tokenizer

* add type hint

* pre-commit

* seems to be in working order

* add model to init

* refactor api models

* nit

* cleanup

* add pbar

* fix type hints

* change optional dependencies

* json encode chat template

* add type hints

* deal with different prompt input requirements

* nits

* fix

* cache inside async

* fix

* fix

* nits

* nits

* nits

* nit

* fixup

* fixup

* nit

* add dummy retry

* add dummy retry

* handle imports; skip failing test

* add type hint

* add tests

* add dependency to tests

* add package names to exception

* nit

* docs; type hints

* handle api key

* nit

* tokenizer bug

* fix tokenizer

* nit

* nit

* add better error messages

* nit

* remove decorator

* CI: install api dep

* revert evaluator.py

* consolidate

* consolidate

* nits

* nit

* fix typealias

* nit

* nit

* nit

* Update lm_eval/models/api_models.py

typo
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/models/openai_completions.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/models/anthropic_llms.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/models/api_models.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* fix typo

* add news section

* add info for API

* pre-commit

* typo

* fix bug: unpack loglikelihood requests

* fix bug: shared gen_kwargs mutated

* nit: handle copy properly

* Update README.md

* Update README.md

* Update README.md

* Update api_models.py

* Update README.md

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 4a62757d
...@@ -56,7 +56,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[dev,sentencepiece,api]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
...@@ -84,7 +84,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[dev,optimum,deepsparse,sparseml]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
       - name: Test with pytest
         run: python -m pytest tests/models --showlocals -s -vv
       - name: Archive artifacts
......
...@@ -2,6 +2,15 @@
 [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10256836.svg)](https://doi.org/10.5281/zenodo.10256836)
 
+---
+
+*Latest News 📣*
+- [2024/07] API model support has been updated and refactored, introducing support for batched and async requests, and making it significantly easier to customize and use for your own purposes. **To run Llama 405B, we recommend using VLLM's OpenAI-compliant API to host the model, and use the `local-completions` model type to evaluate the model.**
+- [2024/07] New Open LLM Leaderboard tasks have been added ! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.
+
+---
+
 ## Announcement
 **A new v0.4.0 release of lm-evaluation-harness is available** !
...@@ -21,6 +30,8 @@ Please see our updated documentation pages in `docs/` for more details.
 Development will be continuing on the `main` branch, and we encourage you to give us feedback on what features are desired and how to improve the library further, or ask questions, either in issues or PRs on GitHub, or in the [EleutherAI discord](https://discord.gg/eleutherai)!
 
+---
+
 ## Overview
 This project provides a unified framework to test generative language models on a large number of different evaluation tasks.
...@@ -112,7 +123,7 @@ For cases where your model can fit on a single GPU, this allows you to evaluate
 The second way of using `accelerate` for multi-GPU evaluation is when your model is *too large to fit on a single GPU.*
-In this setting, run the library *outside of the `accelerate` launcher*, but passing `parallelize=True` to `--model_args` as follows:
+In this setting, run the library *outside the `accelerate` launcher*, but passing `parallelize=True` to `--model_args` as follows:
 ```
 lm_eval --model hf \
...@@ -217,12 +228,12 @@ lm_eval --model openai-completions \
 We also support using your own local inference server with servers that mirror the OpenAI Completions and ChatCompletions APIs.
 ```bash
-lm_eval --model local-chat-completions --tasks gsm8k --model_args model=facebook/opt-125m,base_url=http://{yourip}:8000/v1
+lm_eval --model local-completions --tasks gsm8k --model_args model=facebook/opt-125m,base_url=http://{yourip}:8000/v1,num_concurrent=1,max_retries=3,tokenized_requests=False
 ```
 Note that for externally hosted models, configs such as `--device` and `--batch_size` should not be used and do not function. Just like you can use `--model_args` to pass arbitrary arguments to the model constructor for local models, you can use it to pass arbitrary arguments to the model API for hosted models. See the documentation of the hosting service for information on what arguments they support.
 | API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
-|-------------------------|--------------|-----------------------------------------------------|------------------------------------------------|------------------------------------------------------------|
+|-------------------------|--------------|-----------------------------------------------------|------------------------------------------------------------------------------------|------------------------------------------------------------|
 | OpenAI Completions | :heavy_check_mark: | `openai-completions`, `local-completions` | All OpenAI Completions API models | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
 | OpenAI ChatCompletions | :heavy_check_mark: | `openai-chat-completions`, `local-chat-completions` | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
 | Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
...@@ -236,7 +247,7 @@ Note that for externally hosted models, configs such as `--device` and `--batch_
 | Neuron via AWS Inf2 (Causal LMs) | ✔️ | `neuronx` | Any decoder-only AutoModelForCausalLM supported to run on [huggingface-ami image for inferentia2](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
 | [Neural Magic DeepSparse](https://github.com/neuralmagic/deepsparse) | ✔️ | `deepsparse` | Any LM from [SparseZoo](https://sparsezoo.neuralmagic.com/) or on [HF Hub with the "deepsparse" tag](https://huggingface.co/models?other=deepsparse) | `generate_until`, `loglikelihood` |
 | [Neural Magic SparseML](https://github.com/neuralmagic/sparseml) | ✔️ | `sparseml` | Any decoder-only AutoModelForCausalLM from [SparseZoo](https://sparsezoo.neuralmagic.com/) or on [HF Hub](https://huggingface.co/neuralmagic). Especially useful for models with quantization like [`zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized`](https://sparsezoo.neuralmagic.com/models/llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
-| Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts GET requests using HF models and mirror's OpenAI's Completions or ChatCompletions interface | `generate_until` | |
+| Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` | Support for OpenAI API-compatible servers, with easy customization for other APIs. | `generate_until`, `loglikelihood`, `loglikelihood_rolling` | |
 Models which do not supply logits or logprobs can be used with tasks of type `generate_until` only, while local models, or APIs that supply logprobs/logits of their prompts, can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
...@@ -437,8 +448,8 @@ The best way to get support is to open an issue on this repo or join the [Eleuth
 Extras dependencies can be installed via `pip install -e ".[NAME]"`
 | Name | Use |
-|---------------|---------------------------------------|
-| anthropic | For using Anthropic's models |
+|-----------------|----------------------------------------------|
+| api | For using api models (Anthropic, OpenAI API) |
 | deepsparse | For running NM's DeepSparse models |
 | dev | For linting PRs and contributions |
 | gptq | For loading models with GPTQ |
...@@ -448,7 +459,6 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
 | mamba | For loading Mamba SSM models |
 | math | For running math task answer checking |
 | multilingual | For multilingual tokenizers |
-| openai | For using OpenAI's models |
 | optimum | For running Intel OpenVINO models |
 | promptsource | For using PromptSource prompts |
 | sentencepiece | For using the sentencepiece tokenizer |
...@@ -456,7 +466,7 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
 | testing | For running library test suite |
 | vllm | For loading models with vLLM |
 | zeno | For visualizing results with Zeno |
-|---------------|---------------------------------------|
+| --------------- | --------------------------------------- |
 | all | Loads all extras (not recommended) |
 ## Cite as
......
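To make the news entry and the `local-completions` README example above concrete, a minimal end-to-end sketch follows. The model name, port, concurrency, and tensor-parallel setting are illustrative assumptions rather than values taken from this commit:

```bash
# Host the model behind vLLM's OpenAI-compatible server
# (model name, port, and --tensor-parallel-size are placeholders).
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3.1-405B-Instruct \
    --tensor-parallel-size 8 \
    --port 8000

# Evaluate against that endpoint with the refactored local-completions model type,
# mirroring the README example above.
lm_eval --model local-completions \
    --tasks gsm8k \
    --model_args model=meta-llama/Meta-Llama-3.1-405B-Instruct,base_url=http://localhost:8000/v1,num_concurrent=4,max_retries=3,tokenized_requests=False
```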
...@@ -55,7 +55,7 @@ class LM(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
+    def loglikelihood_rolling(self, requests) -> List[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
...@@ -101,14 +101,13 @@
         """Generate greedily until a stopping sequence
 
         :param requests: list[Instance]
-            A list of Instance objects with property `args` which returns a tuple (context, until).
+            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
             context: str
                 Context string
-            until: [str]
-                The string sequences to generate until. These string sequences
-                may each span across multiple tokens, or may be part of one token.
+            gen_kwargs: dict
+                A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
         :return: list[str]
-            A list of strings continuation
+            A list of model generated continuations.
             continuation: str
                 The generated continuation.
         """
...@@ -325,14 +324,19 @@ class TemplateLM(LM):
         return self.eot_token_id
 
     @abc.abstractmethod
-    def tok_encode(self, string: str, **kwargs):
+    def tok_encode(self, string: str, **kwargs) -> List[int]:
+        """
+        Tokenize a string using the model's tokenizer and return a list of token IDs.
+        """
         pass
 
     @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, **kwargs):
+    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
         pass
 
-    def _encode_pair(self, context, continuation):
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
...@@ -373,7 +377,7 @@
     @abc.abstractmethod
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+    ) -> List[float]:
         pass
 
     @abc.abstractmethod
......
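To make the updated `generate_until` contract above concrete, here is a small illustrative sketch of the request shape it now documents; the context and generation settings are invented for illustration:

```python
# Each generate_until Instance now carries (context, gen_kwargs) in `args`,
# rather than (context, until). Values below are purely illustrative.
context = "Question: What is 2 + 2?\nAnswer:"
gen_kwargs = {
    "until": ["\n\n"],    # stop sequences
    "max_gen_toks": 32,   # generation budget
    "temperature": 0.0,   # greedy decoding
}
request_args = (context, gen_kwargs)  # what `Instance.args` yields per request
```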
 from . import (
     anthropic_llms,
+    api_models,
     dummy,
     gguf,
     huggingface,
......
-from typing import Any, List, Tuple
+import os
+from functools import cached_property
+from typing import Any, Dict, List, Tuple, Union
 
 from tqdm import tqdm
 
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.models.openai_completions import LocalCompletionsAPI
 from lm_eval.models.utils import retry_on_specific_exceptions
...@@ -138,7 +141,7 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
     return messages()
 
 
-@register_model("anthropic")
+@register_model("anthropic-completions")
 class AnthropicLM(LM):
     REQ_CHUNK_SIZE = 20  # TODO: not used
...@@ -271,90 +274,89 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install
 @register_model("anthropic-chat", "anthropic-chat-completions")
-class AnthropicChatLM(AnthropicLM):
-    REQ_CHUNK_SIZE = 20  # TODO: not used
-
-    def __init__(
-        self,
-        model: str,
-        batch_size: int = 1,
-        max_tokens: int = 256,
-        temperature: float = 0,  # defaults to 1
-        **kwargs,  # top_p, top_k, etc.
-    ) -> None:
-        """Anthropic API wrapper.
-
-        :param model: str
-            Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
-        :param max_tokens: int
-            Maximum number of tokens to sample from the model
-        :param temperature: float
-            Sampling temperature
-        :param kwargs: Any
-            Additional model_args to pass to the API client
-        """
-        super().__init__()
-
-        try:
-            import anthropic
-        except ModuleNotFoundError:
-            raise Exception(
-                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
-please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
-            )
-
-        self.model = model
-        # defaults to os.environ.get("ANTHROPIC_API_KEY")
-        self.client = anthropic.Anthropic()
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.tokenizer = self.client.get_tokenizer()
-        self.kwargs = kwargs
-
-    @property
-    def max_gen_toks(self) -> int:
-        return self.max_tokens
-
-    def generate_until(self, requests) -> List[str]:
-        try:
-            import anthropic
-        except ModuleNotFoundError:
-            raise Exception(
-                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
-please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
-            )
-
-        if not requests:
-            return []
-
-        _requests: List[Tuple[str, dict]] = [req.args for req in requests]
-
-        res = []
-        for request in tqdm(_requests):
-            try:
-                inp = request[0]
-                request_args = request[1]
-                # generation_kwargs
-                until = request_args.get("until")
-                max_tokens = request_args.get("max_gen_toks", self.max_length)
-                temperature = request_args.get("temperature", self.temperature)
-                response = anthropic_chat(
-                    client=self.client,
-                    model=self.model,
-                    prompt=inp,
-                    max_tokens=max_tokens,
-                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
-                    stop=until,  # type: ignore
-                    **self.kwargs,
-                )
-                res.append(response)
-
-                self.cache_hook.add_partial("generate_until", request, response)
-            except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
-                eval_logger.critical(f"Server unreachable: {e.__cause__}")
-                break
-            except anthropic.APIStatusError as e:  # type: ignore # noqa: F821
-                eval_logger.critical(f"API error {e.status_code}: {e.message}")
-                break
-
-        return res
+class AnthropicChat(LocalCompletionsAPI):
+    def __init__(
+        self,
+        base_url="https://api.anthropic.com/v1/messages",
+        tokenizer_backend=None,
+        **kwargs,
+    ):
+        super().__init__(
+            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
+        )
+        eval_logger.warning(
+            "Chat completions does not support batching. Defaulting to batch size 1."
+        )
+        self._batch_size = 1
+        self.anthropic_version = "2023-06-01"
+        eval_logger.warning(
+            f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
+        )
+
+    @cached_property
+    def api_key(self):
+        """Override this property to return the API key for the API request."""
+        key = os.environ.get("ANTHROPIC_API_KEY", None)
+        if key is None:
+            raise ValueError(
+                "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
+            )
+        return key
+
+    @cached_property
+    def header(self):
+        return {
+            "x-api-key": f"{self.api_key}",
+            "anthropic-version": self.anthropic_version,
+        }
+
+    def _create_payload(
+        self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
+    ) -> dict:
+        system = (
+            messages[0].get("content") if messages[0].get("role") == "system" else None
+        )
+        if system:
+            messages = messages[1:]
+        gen_kwargs.pop("do_sample", False)
+        max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
+        temperature = gen_kwargs.pop("temperature", 0)
+        stop = gen_kwargs.pop("until", ["\n\nHuman:"])
+        if not isinstance(stop, list):
+            stop = [stop]
+        out = {
+            "messages": messages,
+            "model": self.model,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "stop_sequences": stop,
+            **gen_kwargs,
+        }
+        if system:
+            out["system"] = system
+        return out
+
+    def parse_generations(
+        self, outputs: Union[Dict, List[Dict]], **kwargs
+    ) -> List[str]:
+        res = []
+        if not isinstance(outputs, list):
+            outputs = [outputs]
+        for out in outputs:
+            for choices in out["content"]:
+                res.append(choices["text"])
+        return res
+
+    def tok_encode(
+        self,
+        string: str,
+        left_truncate_len=None,
+        add_special_tokens=None,
+        **kwargs,
+    ) -> List[str]:
+        return [string]
+
+    def _loglikelihood_tokens(self, requests, **kwargs):
+        raise NotImplementedError(
+            "Anthropic Chat Completions API does not support the return of log"
+        )
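For reference, a hedged sketch of invoking the rewritten Anthropic chat model from the CLI is shown below. The model name, task, and retry setting are placeholders, and the `--apply_chat_template` flag is assumed to be available in this version of the harness so that requests are sent as chat messages:

```bash
# Hypothetical invocation; model name, task, and settings are placeholders.
export ANTHROPIC_API_KEY=...   # read by AnthropicChat.api_key
lm_eval --model anthropic-chat-completions \
    --tasks gsm8k \
    --apply_chat_template \
    --model_args model=claude-3-5-sonnet-20240620,max_retries=3
```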
import abc
import asyncio
import copy
import itertools
import json
from collections import namedtuple
from functools import cached_property
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Union,
)
try:
import requests
from aiohttp import ClientSession, TCPConnector
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
except ModuleNotFoundError:
pass
from importlib.util import find_spec
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
JsonChatStr = namedtuple("JsonChatStr", ["prompt"])
eval_logger = utils.eval_logger
class TemplateAPI(TemplateLM):
def __init__(
self,
model: str = None,
pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
base_url: str = None,
tokenizer: Optional[str] = None,
# Loglikelihood tasks require a tokenizer to calculate context lengths,
# however the requests can be sent as a string if the API doesn't support token inputs.
# use tokenized_requests=False
tokenizer_backend: Optional[
Literal["tiktoken", "huggingface", None]
] = "huggingface",
truncate: bool = False,
# number of concurrent requests. More useful if not batching
num_concurrent: int = 1,
max_retries: int = 3,
max_gen_toks: int = 256,
batch_size: Union[str, int] = 1,
seed: int = 1234,
max_length: Optional[int] = 2048,
add_bos_token: bool = False,
custom_prefix_token_id=None,
# send the requests as tokens or strings
tokenized_requests=True,
**kwargs,
) -> None:
super().__init__()
missing_packages = [
pkg
for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
if find_spec(pkg) is None
]
if missing_packages:
raise ModuleNotFoundError(
f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
)
self.model = model or pretrained
self.base_url = base_url
self.tokenizer = tokenizer
if not isinstance(batch_size, int) and "auto" in batch_size:
eval_logger.warning(
"Automatic batch size is not supported for API models. Defaulting to batch size 1."
)
elif int(batch_size) > 1:
eval_logger.warning(
"Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
)
self._batch_size = int(batch_size) if batch_size != "auto" else 1
self._truncate = truncate
self._max_gen_toks = int(max_gen_toks)
self._seed = int(seed)
self.max_length = max_length
if int(num_concurrent) <= 1:
eval_logger.info(
"Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent > 1`."
)
self._concurrent = int(num_concurrent)
self.tokenizer_backend = tokenizer_backend
self.add_bos_token = add_bos_token
self.custom_prefix_token_id = custom_prefix_token_id
self.tokenized_requests = tokenized_requests
self.max_retries = int(max_retries)
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
if self.tokenizer_backend is None:
self.tokenizer = None
self.tokenized_requests = False
else:
if self.tokenizer_backend == "huggingface":
import transformers
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
self.tokenizer if self.tokenizer else self.model
)
# Not used as the API will handle padding but to mirror the behavior of the HFLM
self.tokenizer = configure_pad_token(self.tokenizer)
elif self.tokenizer_backend == "tiktoken":
try:
import tiktoken
self.tokenizer = tiktoken.encoding_for_model(self.model)
except ModuleNotFoundError as e:
raise Exception(
"Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
"Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
) from e
if "openai" not in self.base_url:
eval_logger.warning(
f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
"Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
)
@abc.abstractmethod
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
*,
generate: bool = True,
gen_kwargs: Optional[dict] = None,
**kwargs,
) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API."""
raise NotImplementedError
def create_message(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
generate=False,
) -> Union[List[List[int]], List[dict], List[str], str]:
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
if isinstance(messages[0], JsonChatStr):
# for chat completions we need to decode the json string to list[dict,...]
assert (
self._batch_size == 1
), "non-tokenized chat requests are only supported with batch_size=1"
# list[dict["role":..., "content":...],...]
return json.loads(messages[0].prompt)
if not self.tokenized_requests:
# if messages are tokenized:
if isinstance(messages[0][0], int):
# assuming decoding is lossless. However, this is only for loglikelihood requests
# as we need to compute the context length. For generations, we don't need to tokenize.
messages = self.decode_batch(messages)
if self._batch_size <= 1:
# if batch is 1 return str
return messages[0]
else:
# list[str,...]
return messages
# list[list[int], ...]
return messages
@staticmethod
@abc.abstractmethod
def parse_logprobs(
outputs: Union[Any, List[Any]],
tokens: List[List[int]] = None,
ctxlen: List[int] = None,
**kwargs,
) -> List[Tuple[float, bool]]:
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
"""Method used to parse the generations from the (batched) API response. This method should return a list of str"""
raise NotImplementedError
@cached_property
def api_key(self) -> str:
"""Override this property to return the API key for the API request."""
return ""
@cached_property
def header(self) -> dict:
"""Override this property to return the headers for the API request."""
return {"Authorization": f"Bearer {self.api_key}"}
@property
def chat_template(self) -> str:
"""Must be defined for LM subclasses that implement Chat Templating.
Should return the structure of the chat template applied to user/assistant messages.
Only used for logging and reproducibility.
"""
return ""
@property
def tokenizer_name(self) -> str:
"""Must be defined for LM subclasses which implement Chat Templating.
Should return the name of the tokenizer or chat template used.
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
"""
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]]
) -> Union[str, JsonChatStr]:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
)
else:
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(json.dumps(chat_history))
@cached_property
def eot_token_id(self) -> Optional[int]:
if self.tokenizer is None:
return None
else:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token_id
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.eot_token
@cached_property
def prefix_token_id(self) -> Optional[int]:
if self.tokenizer is None:
return None
else:
if self.custom_prefix_token_id is not None:
return self.custom_prefix_token_id
if self.tokenizer_backend == "huggingface":
if self.tokenizer.bos_token_id is not None:
return self.tokenizer.bos_token_id
return self.tokenizer.eos_token_id
else:
return self.tokenizer.eot_token
def tok_encode(
self,
string: str,
left_truncate_len: int = None,
add_special_tokens: bool = False,
truncation: bool = False,
**kwargs,
) -> Union[List[List[int]], List[int], List[str]]:
if self.tokenizer_backend is None:
return [string]
elif self.tokenizer_backend == "huggingface":
# by default for CausalLM - false or self.add_bos_token is set
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
return_attention_mask=False,
).input_ids
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
if not isinstance(string, str):
encoding = [enc[-left_truncate_len:] for enc in encoding]
else:
encoding = encoding[-left_truncate_len:]
return encoding
else:
try:
encoding = self.tokenizer.encode(string)
except Exception:
encoding = self.tokenizer.encode_batch(string)
return encoding
def decode_batch(self, tokens: List[List[int]]) -> List[str]:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.batch_decode(tokens)
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode_batch(tokens)
def model_call(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
*,
generate: bool = True,
gen_kwargs: Optional[Dict] = None,
**kwargs,
) -> Optional[dict]:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
try:
response = requests.post(
self.base_url,
json=self._create_payload(
self.create_message(messages),
generate=generate,
gen_kwargs=gen_kwargs,
**kwargs,
),
headers=self.header,
)
if not response.ok:
eval_logger.warning(
f"API request failed with error message: {response.text}. Retrying..."
)
response.raise_for_status()
return response.json()
except RetryError:
eval_logger.error(
"API request failed after multiple retries. Please check the API status."
)
return None
async def amodel_call(
self,
session: ClientSession,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
*,
generate: bool = True,
cache_keys: list = None,
ctxlens: Optional[List[int]] = None,
gen_kwargs: Optional[Dict] = None,
**kwargs,
) -> Union[List[str], List[Tuple[float, bool]], None]:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
payload = self._create_payload(
self.create_message(messages),
generate=generate,
gen_kwargs=gen_kwargs,
**kwargs,
)
cache_method = "generate_until" if generate else "loglikelihood"
try:
async with session.post(
self.base_url,
json=payload,
headers=self.header,
) as response:
if not response.ok:
error_text = await response.text()
eval_logger.warning(
f"API request failed with error message: {error_text}. Retrying..."
)
# raising exception will retry the request
response.raise_for_status()
outputs = await response.json()
answers = (
self.parse_generations(
outputs=outputs,
)
if generate
else self.parse_logprobs(
outputs=outputs,
tokens=messages,
ctxlens=ctxlens,
)
)
if cache_keys:
for res, cache in zip(answers, cache_keys):
self.cache_hook.add_partial(cache_method, cache, res)
return answers
# If the retries also fail
except RetryError:
eval_logger.error(
"API request failed after multiple retries. Please check the API status."
)
return None
def batch_logliklehood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]]
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
inputs = []
ctxlens = []
cache_keys = []
for chunk in chunks:
for cache_key, context_enc, continuation_enc in chunk:
inp = (context_enc + continuation_enc)[-(self.max_length) :]
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length)
)
inputs.append(inp)
ctxlens.append(ctxlen)
cache_keys.append(cache_key)
return inputs, ctxlens, cache_keys
async def get_batched_requests(
self,
requests: list,
cache_keys: list,
*,
generate: bool = True,
ctxlens: List[int] = None,
**kwargs,
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent)
async with ClientSession(connector=conn) as session:
retry_: Callable[..., Awaitable[Any]] = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.amodel_call)
# Create tasks for each batch of request
tasks = [
asyncio.create_task(
retry_(
session=session,
messages=message,
cache_keys=cache_key,
generate=generate,
ctxlens=ctxlen,
**kwargs,
)
)
for message, cache_key, ctxlen in zip(
chunks(requests, n=self._batch_size),
chunks(cache_keys, n=self._batch_size),
chunks(ctxlens, n=self._batch_size),
)
]
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
assert (
self.tokenizer is not None
), "Tokenizer is required for loglikelihood tasks to compute context lengths."
res = []
def _collate(req: LogLikelihoodInputs):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = req[1] + req[2]
return -len(toks), tuple(toks)
re_ord = Collator(
requests,
sort_fn=_collate,
group_by=None,
)
# if concurrent then we'll batch in the async context
chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
inputs, ctxlens, cache_keys = self.batch_logliklehood_requests([chunk])
outputs = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.model_call)(messages=self.create_message(inputs), generate=False)
if isinstance(outputs, dict):
outputs = [outputs]
for answer_, cache_key in zip(
self.parse_logprobs(
outputs=outputs, tokens=inputs, ctxlens=ctxlens
),
cache_keys,
):
if answer_ is not None:
res.append(answer_)
# partial caching
if cache_key is not None:
self.cache_hook.add_partial(
"loglikelihood", cache_key, answer_
)
pbar.update(1)
else:
inputs, ctxlens, cache_keys = self.batch_logliklehood_requests(chunked)
res = itertools.chain.from_iterable(
asyncio.run(
self.get_batched_requests(
inputs, cache_keys, generate=False, ctxlens=ctxlens
)
)
)
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = []
def _collate_gen(_requests):
# sort by the length of the non-tokenized contexts
return -len(_requests[0])
# Let the API deal with tokenization
requests, all_gen_kwargs = zip(*(req.args for req in requests))
if self.tokenized_requests:
encodings_list = self.tok_encode(
requests, add_special_tokens=self.add_bos_token
)
else:
encodings_list = [None] * len(requests)
requests = [
(a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
]
re_ord = Collator(
requests,
sort_fn=_collate_gen,
group_by="gen_kwargs",
)
chunked = re_ord.get_batched(
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
)
if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
req = encodings_list if self.tokenized_requests else contexts
outputs = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.model_call)(
messages=req,
generate=True,
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
)
for generated_text, context in zip(
self.parse_generations(
outputs=outputs,
contexts=contexts,
),
contexts,
):
if generated_text is not None:
res.append(generated_text)
# partial caching
if context is not None:
self.cache_hook.add_partial(
"generate_until",
(context, all_gen_kwargs[0]),
generated_text,
)
pbar.update(1)
else:
for chunk in chunked:
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
req = encodings_list if self.tokenized_requests else contexts
results = itertools.chain.from_iterable(
asyncio.run(
self.get_batched_requests(
req,
cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
generate=True,
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
)
)
)
res.extend(results)
return re_ord.get_original(res)
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.prefix_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
disable_tqdm=True,
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
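Since the commit message highlights that the refactor makes API models significantly easier to customize, a minimal sketch of what a downstream subclass of `TemplateAPI` could look like is included here for orientation. The endpoint, payload keys, and response shape below are assumptions about a hypothetical JSON completion server, not part of this commit:

```python
# Hypothetical subclass of TemplateAPI for a made-up JSON completion server.
# The payload keys and response shape are illustrative assumptions.
from typing import Any, List, Optional, Tuple, Union

from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI


@register_model("my-custom-api")
class MyCustomAPI(TemplateAPI):
    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate: bool = True,
        gen_kwargs: Optional[dict] = None,
        **kwargs,
    ) -> dict:
        # Generation-only sketch: build the JSON body the hypothetical server expects.
        gen_kwargs = gen_kwargs or {}
        return {
            "model": self.model,
            "prompt": messages,
            "max_tokens": gen_kwargs.pop("max_gen_toks", self._max_gen_toks),
            "temperature": gen_kwargs.pop("temperature", 0),
            "stop": gen_kwargs.pop("until", None),
            **gen_kwargs,
        }

    @staticmethod
    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
        # Assumes the server answers with {"choices": [{"text": ...}, ...]}.
        outputs = outputs if isinstance(outputs, list) else [outputs]
        return [choice["text"] for out in outputs for choice in out["choices"]]

    @staticmethod
    def parse_logprobs(
        outputs: Union[Any, List[Any]],
        tokens: List[List[int]] = None,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        # Only needed for loglikelihood-style tasks; omitted in this sketch.
        raise NotImplementedError
```

Such a subclass could then be selected with `--model my-custom-api --model_args model=...,base_url=...`, in the same way as the built-in API models above.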
import copy
import os import os
from collections import defaultdict from functools import cached_property
from importlib.util import find_spec from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple
from tqdm import tqdm
import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import LM, TemplateLM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import retry_on_specific_exceptions from lm_eval.models.api_models import TemplateAPI
from lm_eval.utils import eval_logger from lm_eval.utils import eval_logger
def get_result(response) -> Tuple[float, bool]: @register_model("local-completions")
"""Process results from OpenAI API response. class LocalCompletionsAPI(TemplateAPI):
:param response: dict
OpenAI API Response
:return:
continuation_logprobs: np.array
Log probabilities of continuation tokens
is_greedy: bool
whether argmax matches given continuation exactly
"""
is_greedy = True
logprobs = response.logprobs.token_logprobs
continuation_logprobs = sum(logprobs)
for i in range(len(response.logprobs.token_logprobs)):
token = response.logprobs.token_logprobs[i]
top_tokens = response.logprobs.top_logprobs[i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
def oa_completion(client, chat: bool = False, **kwargs):
"""Query OpenAI API for completion.
Retry with back-off until they respond
"""
if not find_spec("openai") or not find_spec("tiktoken"):
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
"Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
)
else:
import openai
def _exception_callback(e: Exception, sleep_time: float) -> None:
import traceback
traceback.print_exc()
@retry_on_specific_exceptions(
on_exceptions=[openai.OpenAIError],
max_retries=None, # retry forever, consider changing
on_exception_callback=_exception_callback,
)
def completion():
if chat:
return client.chat.completions.create(**kwargs)
else:
return client.completions.create(**kwargs)
return completion()
@register_model("openai-completions", "local-completions")
class OpenaiCompletionsLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048
def __init__( def __init__(
self, self,
model: str, base_url=None,
base_url: str = None, tokenizer_backend="huggingface",
tokenizer: Optional[str] = None, **kwargs,
tokenizer_backend: Literal["tiktoken", "huggingface"] = "tiktoken", ):
truncate: bool = False, super().__init__(
max_gen_toks: int = 256, base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
batch_size: int = 1,
seed: int = 1234,
max_length: Optional[int] = None,
) -> None:
"""
:param engine: str
OpenAI API engine (e.g. gpt-3.5-turbo-instruct)
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
self.seed = seed
try:
import openai # noqa: E401
import tiktoken
except ModuleNotFoundError:
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .\"[openai]\"`",
)
self.model = model
self.base_url = base_url
self.tokenizer_backend = tokenizer_backend
self.truncate = truncate
self._batch_size = int(batch_size)
self._max_gen_toks = max_gen_toks
self._max_length = max_length
# if we have a local model, use HF tokenizer over tiktoken
if self.tokenizer_backend == "huggingface":
import transformers # noqa: E401
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer if tokenizer else self.model
)
self.vocab_size = self.tokenizer.vocab
self.end_of_text_token_id = self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken":
if self.base_url:
eval_logger.warning(
f"Passed `base_url={self.base_url}` but using Tiktoken tokenizer backend. "
"Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
)
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.end_of_text_token_id = self.tokenizer.eot_token
else:
raise ValueError(
f"Expected tokenizer_backend to be one of ['tiktoken', 'huggingface'] but got {self.tokenizer_backend}"
) )
# Read from environment variable OPENAI_API_KEY def _create_payload(
# Set to EMPTY for local self,
openai.api_key = os.environ["OPENAI_API_KEY"] messages: Union[List[List[int]], List[dict], List[str], str],
if self.base_url: generate=False,
self.client = openai.OpenAI(base_url=self.base_url) gen_kwargs: Optional[dict] = None,
else: **kwargs,
self.client = openai.OpenAI() ) -> dict:
if generate:
@property gen_kwargs.pop("do_sample", False)
def eot_token_id(self): max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
return self.end_of_text_token_id temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"])
@property return {
def max_length(self) -> int: "prompt": messages,
if self._max_length: "model": self.model,
return self._max_length "max_tokens": max_tokens,
"temperature": temperature,
"stop": stop,
**gen_kwargs,
}
else: else:
return self._DEFAULT_MAX_LENGTH return {
"model": self.model,
@property "prompt": messages,
def max_gen_toks(self) -> int: "max_tokens": 1,
return self._max_gen_toks "logprobs": 1,
"echo": True,
@property }
def batch_size(self) -> int:
return self._batch_size @staticmethod
def parse_logprobs(
@property outputs: Union[Dict, List[Dict]],
def device(self): tokens: List[List[int]] = None,
# Isn't used because we override _loglikelihood_tokens ctxlens: List[int] = None,
raise NotImplementedError() **kwargs,
def tok_encode(self, string: str, **kwargs) -> List[int]:
return self.tokenizer.encode(string)
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]: ) -> List[Tuple[float, bool]]:
res = [] res = []
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
for choice, ctxlen in zip(out["choices"], ctxlens):
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens = choice["logprobs"]["token_logprobs"][ctxlen:-1]
top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
is_greedy = True
for tok, top in zip(tokens, top_logprobs):
if tok != max(top, key=top.get):
is_greedy = False
break
res.append((logprobs, is_greedy))
return res
def _collate(x): @staticmethod
# this doesn't efficiently handle last-token differences yet, but those are kinda annoying because def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
# it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
# we care about, and so we need some kind of backup for when it isn't
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
for chunk in tqdm(
list(lm_eval.models.utils.chunks(re_ord.get_reordered(), self.batch_size)),
disable=disable_tqdm,
):
inps = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp)
ctxlens.append(ctxlen)
response = oa_completion(
client=self.client,
model=self.model,
prompt=inps,
max_tokens=0,
temperature=0.0,
logprobs=10,
seed=self.seed,
)
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
response.choices, ctxlens, chunk
):
answer = get_result(resp)
res.append(answer)
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
if not requests:
return []
res = [] res = []
requests = [req.args for req in requests] if not isinstance(outputs, list):
outputs = [outputs]
def _collate(x): for out in outputs:
toks = self.tok_encode(x[0]) for choices in out["choices"]:
return len(toks), x[0] res.append(choices["text"])
return res
re_ord = utils.Reorderer(requests, _collate) @property
def api_key(self):
def sameuntil_chunks(xs, size): return os.environ.get("OPENAI_API_KEY", "")
ret = []
lastuntil = xs[0][1]
for x in xs:
if len(ret) >= size or x[1] != lastuntil:
yield ret, lastuntil
ret = []
lastuntil = x[1]
ret.append(x)
if ret:
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until` @register_model("local-chat-completions")
for chunk, request_args in tqdm( class LocalChatCompletion(LocalCompletionsAPI):
list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size)), def __init__(
disable=disable_tqdm, self,
base_url=None,
tokenizer_backend=None,
tokenized_requests=False,
**kwargs,
): ):
inps = [] super().__init__(
self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks) base_url=base_url,
for context, _ in chunk: tokenizer_backend=tokenizer_backend,
context_enc = self.tok_encode(context) tokenized_requests=tokenized_requests,
inp = context_enc[-(self.max_length - self.max_gen_toks) :] **kwargs,
inps.append(inp)
until = request_args.get("until", ["<|endoftext|>"])
request_args["temperature"] = request_args.get("temperature", 0)
response = oa_completion(
client=self.client,
model=self.model,
prompt=inps,
max_tokens=self.max_gen_toks,
stop=until,
seed=self.seed,
**{
k: v
for k, v in request_args.items()
if k not in {"do_sample", "max_gen_toks", "until"}
},
)
for resp, (context, args_) in zip(response.choices, chunk):
s = getattr(resp, "text")
until_ = until
for term in until_:
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial(
"generate_until", (context, {"until": until_}), s
)
res.append(s)
return re_ord.get_original(res)
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False
) -> List[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
) )
if self._batch_size > 1:
eval_logger.warning(
"Chat completions does not support batching. Defaulting to batch size 1."
)
self._batch_size = 1
def _create_payload(
self, messages: List[Dict], generate=False, gen_kwargs: dict = None, **kwargs
) -> dict:
gen_kwargs.pop("do_sample", False)
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"])
if not isinstance(stop, (list, tuple)):
stop = [stop]
return {
"messages": messages,
"model": self.model,
"max_tokens": max_tokens,
"temperature": temperature,
"stop": stop[:4],
**gen_kwargs,
}
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
for choices in out["choices"]:
res.append(choices["message"]["content"])
return res
def tok_encode(
self,
string: Union[str, Any],
left_truncate_len=None,
add_special_tokens=None,
**kwargs,
) -> Union[List[str], List[int], Any]:
return string
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case def _loglikelihood_tokens(self, requests, **kwargs):
rolling_token_windows = [(None,) + x for x in rolling_token_windows] raise NotImplementedError(
"Loglikelihood is not supported for chat completions. Consider using the completions API instead."
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
disable_tqdm=True,
) )
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll) @register_model(
loglikelihoods.append(string_nll) "openai-completions",
return loglikelihoods )
class OpenAICompletionsAPI(LocalCompletionsAPI):
@register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM):
def __init__( def __init__(
self, self,
model: str = "gpt-3.5-turbo", # GPT model or Local model using HuggingFace model paths base_url="https://api.openai.com/v1/completions",
base_url: str = None, tokenizer_backend="tiktoken",
truncate: bool = False,
**kwargs, **kwargs,
) -> None: ):
""" super().__init__(
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
:param model: str
Implements an OpenAI-style chat completion API for
accessing both OpenAI OR locally-hosted models using
HuggingFace Tokenizer
OpenAI API model (e.g. gpt-3.5-turbo)
using the **gen_kwargs passed on init
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
try:
import openai # noqa: E401
except ModuleNotFoundError:
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
) )
self.model = model
self.base_url = base_url
self.truncate = truncate
# Read from environment variable OPENAI_API_KEY
# Set to EMPTY for local
if self.base_url:
self.client = openai.OpenAI(base_url=self.base_url)
else:
self.client = openai.OpenAI() # openai.AsyncOpenAI()
@property
def max_length(self) -> int:
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
return 2048
@property
def max_gen_toks(self) -> int:
return 256
@property @cached_property
def batch_size(self): def api_key(self):
# Isn't used because we override _loglikelihood_tokens """Override this property to return the API key for the API request."""
raise NotImplementedError() key = os.environ.get("OPENAI_API_KEY", None)
if key is None:
@property raise ValueError(
def device(self): "API key not found. Please set the OPENAI_API_KEY environment variable."
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
res = defaultdict(list)
re_ords = {}
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer(
[req.args for req in reqs], lambda x: (-len(x[0]), x[0])
) )
return key
pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0))) def _loglikelihood_tokens(self, requests, **kwargs):
for key, re_ord in re_ords.items(): assert (
# n needs to be 1 because messages in self.model != "gpt-3.5-turbo"
# chat completion are not batch but ), "Loglikelihood is not supported for gpt-3.5-turbo"
# is regarded as a single conversation. return super()._loglikelihood_tokens(requests, **kwargs)
chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
inps = [{"role": "user", "content": context} for context in contexts]
gen_kwargs = all_gen_kwargs[0]
until = None
if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
)
kwargs["stop"] = until
kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
else:
raise ValueError(
f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
)
response = oa_completion( @register_model("openai-chatcompletions")
client=self.client, class OpenAIChatCompletion(LocalChatCompletion):
chat=True, def __init__(
messages=inps, self,
model=self.model, base_url="https://api.openai.com/v1/chat/completions",
tokenizer_backend=None,
tokenized_requests=False,
**kwargs,
):
super().__init__(
base_url=base_url,
tokenizer_backend=tokenizer_backend,
tokenized_requests=tokenized_requests,
**kwargs, **kwargs,
) )
for resp, (context, args_) in zip(response.choices, chunk): @cached_property
s = resp.message.content def api_key(self):
"""Override this property to return the API key for the API request."""
if until is not None: key = os.environ.get("OPENAI_API_KEY", None)
for term in until: if key is None:
if len(term) > 0: raise ValueError(
s = s.split(term)[0] "API key not found. Please set the OPENAI_API_KEY environment variable."
res[key].append(s)
self.cache_hook.add_partial(
"generate_until", (context, {"until": until}), s
) )
pbar.update(1) return key
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close()
return grouper.get_original(res)
def loglikelihood(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
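A rough end-to-end usage sketch for the refactored OpenAI chat model (assumptions: the package is installed with the new `api` extra, OPENAI_API_KEY is already exported, and the task name and limit below are illustrative only):

# Usage sketch, not part of this diff.
import lm_eval

results = lm_eval.simple_evaluate(
    model="openai-chat-completions",   # registered name from the class above
    model_args="model=gpt-3.5-turbo",  # forwarded to the model constructor
    tasks=["gsm8k"],                   # illustrative task name
    limit=8,                           # small smoke-test subset
)
print(results["results"])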
...@@ -57,7 +57,7 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness" ...@@ -57,7 +57,7 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
Repository = "https://github.com/EleutherAI/lm-evaluation-harness" Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
[project.optional-dependencies] [project.optional-dependencies]
anthropic = ["anthropic"] api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"] dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"] deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
gptq = ["auto-gptq[triton]>=0.6.0"] gptq = ["auto-gptq[triton]>=0.6.0"]
...@@ -67,7 +67,6 @@ neuronx = ["optimum[neuronx]"] ...@@ -67,7 +67,6 @@ neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"] mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"] math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"] multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
openai = ["openai==1.3.9", "tiktoken"]
optimum = ["optimum[openvino]"] optimum = ["optimum[openvino]"]
promptsource = ["promptsource>=0.2.3"] promptsource = ["promptsource>=0.2.3"]
sentencepiece = ["sentencepiece>=0.1.98"] sentencepiece = ["sentencepiece>=0.1.98"]
......
from unittest.mock import MagicMock, patch
import pytest
from lm_eval.models.openai_completions import LocalCompletionsAPI
@pytest.fixture
def api():
return LocalCompletionsAPI(
base_url="http://test-url.com", tokenizer_backend=None, model="gpt-3.5-turbo"
)
@pytest.fixture
def api_tokenized():
return LocalCompletionsAPI(
base_url="http://test-url.com",
model="EleutherAI/pythia-1b",
tokenizer_backend="huggingface",
)
def test_create_payload_generate(api):
messages = ["Generate a story"]
gen_kwargs = {
"max_tokens": 100,
"temperature": 0.7,
"until": ["The End"],
"do_sample": True,
}
payload = api._create_payload(messages, generate=True, gen_kwargs=gen_kwargs)
assert payload == {
"prompt": ["Generate a story"],
"model": "gpt-3.5-turbo",
"max_tokens": 100,
"temperature": 0.7,
"stop": ["The End"],
}
def test_create_payload_loglikelihood(api):
messages = ["The capital of France is"]
payload = api._create_payload(messages, generate=False, gen_kwargs=None)
assert payload == {
"model": "gpt-3.5-turbo",
"prompt": ["The capital of France is"],
"max_tokens": 1,
"logprobs": 1,
"echo": True,
}
@pytest.mark.parametrize(
"input_messages, generate, gen_kwargs, expected_payload",
[
(
["Hello, how are"],
True,
{"max_gen_toks": 100, "temperature": 0.7},
{
"prompt": "Hello, how are",
"model": "gpt-3.5-turbo",
"max_tokens": 100,
"temperature": 0.7,
"stop": ["<|endoftext|>"],
},
),
(
["Hello, how are", "you"],
True,
{},
{
"prompt": "Hello, how are",
"model": "gpt-3.5-turbo",
"max_tokens": 256,
"temperature": 0,
"stop": ["<|endoftext|>"],
},
),
],
)
def test_model_generate_call_usage(
api, input_messages, generate, gen_kwargs, expected_payload
):
with patch("requests.post") as mock_post:
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_post.return_value = mock_response
# Act
result = api.model_call(
input_messages, generate=generate, gen_kwargs=gen_kwargs
)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
assert "json" in kwargs
assert kwargs["json"] == expected_payload
assert result == {"result": "success"}
@pytest.mark.parametrize(
"input_messages, generate, gen_kwargs, expected_payload",
[
(
[[1, 2, 3, 4, 5]],
False,
None,
{
"model": "EleutherAI/pythia-1b",
"prompt": [[1, 2, 3, 4, 5]],
"max_tokens": 1,
"logprobs": 1,
"echo": True,
},
),
],
)
def test_model_tokenized_call_usage(
api_tokenized, input_messages, generate, gen_kwargs, expected_payload
):
with patch("requests.post") as mock_post:
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_post.return_value = mock_response
# Act
result = api_tokenized.model_call(
input_messages, generate=generate, gen_kwargs=gen_kwargs
)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
assert "json" in kwargs
assert kwargs["json"] == expected_payload
assert result == {"result": "success"}
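A possible companion test for the chat payload builder, mirroring test_create_payload_generate above (assumptions: LocalChatCompletion lives in the same module and can be constructed offline the same way as the LocalCompletionsAPI fixture):

def test_create_payload_chat_sketch():
    # Hypothetical test, not part of this diff.
    from lm_eval.models.openai_completions import LocalChatCompletion

    lm = LocalChatCompletion(
        base_url="http://test-url.com", tokenizer_backend=None, model="gpt-3.5-turbo"
    )
    messages = [{"role": "user", "content": "Generate a story"}]
    gen_kwargs = {"max_gen_toks": 100, "temperature": 0.7, "until": "The End"}
    payload = lm._create_payload(messages, generate=True, gen_kwargs=gen_kwargs)
    assert payload == {
        "messages": messages,
        "model": "gpt-3.5-turbo",
        "max_tokens": 100,
        "temperature": 0.7,
        "stop": ["The End"],
    }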