Unverified Commit 0230356c authored by Baber Abbasi, committed by GitHub

make utility function to handle `until` (#2518)

* make utility function to handle `until`

* fix text
parent 9169899b
......@@ -8,7 +8,7 @@ from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import retry_on_specific_exceptions
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
eval_logger = utils.eval_logger
......@@ -311,7 +311,12 @@ class AnthropicChat(LocalCompletionsAPI):
}
def _create_payload(
self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
self,
messages: List[Dict],
generate=True,
gen_kwargs: dict = None,
eos="\n\nHuman:",
**kwargs,
) -> dict:
system = (
messages[0].get("content") if messages[0].get("role") == "system" else None
......@@ -321,7 +326,7 @@ class AnthropicChat(LocalCompletionsAPI):
gen_kwargs.pop("do_sample", False)
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["\n\nHuman:"])
stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
if not isinstance(stop, list):
stop = [stop]
out = {
......
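A quick illustration of the Anthropic change above, assuming the `handle_stop_sequences` helper defined later in this diff: the `"\n\nHuman:"` turn marker is now appended to task-supplied stop sequences instead of being replaced by them.

```python
from lm_eval.models.utils import handle_stop_sequences

# Default case: no task-supplied `until`, the turn marker is the only stop
# (it is already present, so it is not appended a second time).
print(handle_stop_sequences(["\n\nHuman:"], eos="\n\nHuman:"))  # ['\n\nHuman:']

# Task supplies its own stops: previously these replaced the turn marker,
# now "\n\nHuman:" is appended as well.
print(handle_stop_sequences(["###"], eos="\n\nHuman:"))  # ['###', '\n\nHuman:']
```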
......@@ -80,6 +80,7 @@ class TemplateAPI(TemplateLM):
revision: Optional[str] = "main",
use_fast_tokenizer: bool = True,
verify_certificate: bool = True,
eos_string: str = None,
**kwargs,
) -> None:
super().__init__()
......@@ -124,6 +125,7 @@ class TemplateAPI(TemplateLM):
self.tokenized_requests = tokenized_requests
self.max_retries = int(max_retries)
self.verify_certificate = verify_certificate
self._eos_string = eos_string
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
if self.tokenizer_backend is None:
......@@ -176,6 +178,7 @@ class TemplateAPI(TemplateLM):
generate: bool = True,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
eos: str = None,
**kwargs,
) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API."""
......@@ -268,6 +271,21 @@ class TemplateAPI(TemplateLM):
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.eot_token
@cached_property
def eos_string(self) -> Optional[str]:
if self._eos_string:
return self._eos_string
elif self.tokenizer is not None:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode([self.tokenizer.eot_token])
else:
eval_logger.warning(
"Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
)
return None
@cached_property
def prefix_token_id(self) -> Optional[int]:
if self.tokenizer is None:
......@@ -343,6 +361,7 @@ class TemplateAPI(TemplateLM):
generate=generate,
gen_kwargs=gen_kwargs,
seed=self._seed,
eos=self.eos_string,
**kwargs,
),
headers=self.header,
......
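To make the new `eos_string` property's resolution order explicit, here is a rough standalone restatement; the argument names mirror `TemplateAPI` attributes, and this is a sketch rather than the implementation itself.

```python
from typing import Optional


def resolve_eos_string(eos_string, tokenizer, tokenizer_backend) -> Optional[str]:
    # 1. An explicit `eos_string` model arg always wins.
    if eos_string:
        return eos_string
    # 2. Otherwise fall back to the configured tokenizer, if any.
    if tokenizer is not None:
        if tokenizer_backend == "huggingface":
            return tokenizer.eos_token
        if tokenizer_backend == "tiktoken":
            return tokenizer.decode([tokenizer.eot_token])
    # 3. No way to tell: the real property logs a warning and returns None,
    #    in which case no EOS stop sequence is added to requests.
    return None
```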
......@@ -14,6 +14,7 @@ from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import (
Collator,
flatten_image_list,
handle_stop_sequences,
pad_and_concat,
replace_placeholders,
stop_sequences_criteria,
......@@ -629,7 +630,7 @@ class HFMultimodalLM(HFLM):
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
### Up to here: was identical to non-multimodal HFLM generate_until ###
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
......@@ -646,27 +647,14 @@ class HFMultimodalLM(HFLM):
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
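The multimodal HF backend above, and `HFLM` plus the vLLM backends further down, all get the same treatment: `eos` is decoded once before the chunk loop instead of inside it, and the hand-rolled `until` validation collapses to a single helper call. A condensed, self-contained restatement of the per-chunk handling, assuming the helper added in this diff:

```python
import copy

from lm_eval.models.utils import handle_stop_sequences


def per_chunk_stops(gen_kwargs, eos):
    """What each generate_until loop now does with `gen_kwargs['until']`."""
    if not isinstance(gen_kwargs, dict):
        raise ValueError(
            f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
        )
    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
    return handle_stop_sequences(kwargs.pop("until", None), eos=eos)


# e.g. with eos decoded once from the model's EOT token, say "</s>":
print(per_chunk_stops({"until": ["Q:"], "max_gen_toks": 64}, "</s>"))  # ['Q:', '</s>']
```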
......@@ -33,6 +33,7 @@ from lm_eval.models.utils import (
clear_torch_cache,
configure_pad_token,
get_dtype,
handle_stop_sequences,
pad_and_concat,
stop_sequences_criteria,
)
......@@ -1255,33 +1256,21 @@ class HFLM(TemplateLM):
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
......@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI
from lm_eval.models.utils import handle_stop_sequences
from lm_eval.utils import eval_logger
......@@ -25,6 +26,7 @@ class LocalCompletionsAPI(TemplateAPI):
generate=False,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
eos=None,
**kwargs,
) -> dict:
if generate:
......@@ -34,7 +36,7 @@ class LocalCompletionsAPI(TemplateAPI):
else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"])
stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
return {
"prompt": messages,
"model": self.model,
......@@ -124,6 +126,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
generate=False,
gen_kwargs: dict = None,
seed=1234,
eos=None,
**kwargs,
) -> dict:
assert (
......@@ -135,7 +138,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"])
stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
if not isinstance(stop, (list, tuple)):
stop = [stop]
return {
......@@ -252,6 +255,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
generate=False,
gen_kwargs: dict = None,
seed=1234,
eos="<|endoftext|>",
**kwargs,
) -> dict:
assert (
......@@ -263,7 +267,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"])
stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
if not isinstance(stop, (list, tuple)):
stop = [stop]
output = {
......
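One behavioral note on the API classes above: `LocalCompletionsAPI` and `LocalChatCompletion` now default to `eos=None` and rely on the tokenizer-derived `eos_string`, while `OpenAIChatCompletion` keeps `"<|endoftext|>"` as its fallback. With no tokenizer and no `eos_string`, nothing is appended, which is what the updated test expectations at the bottom of this diff reflect.

```python
from lm_eval.models.utils import handle_stop_sequences

# No `until` and no EOS available: the payload's stop list is empty.
print(handle_stop_sequences(None, None))             # []

# Task-supplied stops pass through unchanged when eos is None.
print(handle_stop_sequences(["hi"], None))           # ['hi']

# OpenAIChatCompletion's fallback EOS is still appended.
print(handle_stop_sequences(None, "<|endoftext|>"))  # ['<|endoftext|>']
```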
......@@ -709,3 +709,21 @@ def flatten_image_list(images: List[List]):
:return: a list of PIL images, via concatenating all the sub-lists in order.
"""
return [image for image_list in images for image in image_list]
def handle_stop_sequences(
until: Union[str, List[str], None], eos: Optional[str]
) -> List[str]:
"""Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
if isinstance(until, str):
until = [until]
elif until is None:
until = []
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
if eos is not None and eos not in until:
until.append(eos)
return until
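Two edge cases of the new helper worth calling out, following directly from the definition above: an EOS already present in `until` is not appended twice, and anything other than `str`, `list`, or `None` raises.

```python
from lm_eval.models.utils import handle_stop_sequences

print(handle_stop_sequences(["</s>", "Q:"], eos="</s>"))  # ['</s>', 'Q:']
handle_stop_sequences(42, eos="</s>")                     # raises ValueError
```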
......@@ -10,7 +10,12 @@ from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, configure_pad_token, undistribute
from lm_eval.models.utils import (
Collator,
configure_pad_token,
handle_stop_sequences,
undistribute,
)
from lm_eval.utils import (
eval_logger,
get_rolling_token_windows,
......@@ -346,6 +351,7 @@ class VLLM(TemplateLM):
desc="Running generate_until requests",
)
# for each different set of kwargs, we execute all requests, by batch.
eos = self.tokenizer.decode(self.eot_token_id)
for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding)
......@@ -353,27 +359,14 @@ class VLLM(TemplateLM):
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tokenizer.decode(self.eot_token_id)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
......@@ -7,7 +7,12 @@ from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, replace_placeholders, undistribute
from lm_eval.models.utils import (
Collator,
handle_stop_sequences,
replace_placeholders,
undistribute,
)
from lm_eval.models.vllm_causallms import VLLM
from lm_eval.utils import eval_logger
......@@ -225,7 +230,7 @@ class VLLM_VLM(VLLM):
group_fn=lambda x: x[1],
)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
eos = self.tokenizer.decode(self.eot_token_id)
for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
......@@ -241,27 +246,14 @@ class VLLM_VLM(VLLM):
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
# add EOS token to stop sequences
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tokenizer.decode(self.eot_token_id)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
......@@ -63,13 +63,13 @@ def test_create_payload_loglikelihood(api):
(
["Hello, how are"],
True,
{"max_gen_toks": 100, "temperature": 0.7},
{"max_gen_toks": 100, "temperature": 0.7, "until": ["hi"]},
{
"prompt": "Hello, how are",
"model": "gpt-3.5-turbo",
"max_tokens": 100,
"temperature": 0.7,
"stop": ["<|endoftext|>"],
"stop": ["hi"],
"seed": 1234,
},
),
......@@ -82,7 +82,7 @@ def test_create_payload_loglikelihood(api):
"model": "gpt-3.5-turbo",
"max_tokens": 256,
"temperature": 0,
"stop": ["<|endoftext|>"],
"stop": [],
"seed": 1234,
},
),
......