Unverified commit 0230356c authored by Baber Abbasi, committed by GitHub

make utility function to handle `until` (#2518)

* make utility function to handle `until`

* fix text
parent 9169899b
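In short, this change replaces the `until`-handling logic that was previously duplicated inline across the HuggingFace, vLLM, and API-based model classes with a single `handle_stop_sequences` helper in `lm_eval/models/utils.py` (added later in this diff). A minimal usage sketch of the helper, with illustrative stop strings rather than values taken from any task:

```python
from lm_eval.models.utils import handle_stop_sequences

# Normalizes `until` (str, list of str, or None) into a list of stop sequences
# and appends the EOS string when one is available and not already present.
print(handle_stop_sequences("###", eos="<|endoftext|>"))  # ['###', '<|endoftext|>']
print(handle_stop_sequences(None, eos="</s>"))            # ['</s>']
print(handle_stop_sequences(None, eos=None))              # []
```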
@@ -8,7 +8,7 @@ from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.openai_completions import LocalCompletionsAPI
-from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions

 eval_logger = utils.eval_logger
@@ -311,7 +311,12 @@ class AnthropicChat(LocalCompletionsAPI):
         }

     def _create_payload(
-        self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
+        self,
+        messages: List[Dict],
+        generate=True,
+        gen_kwargs: dict = None,
+        eos="\n\nHuman:",
+        **kwargs,
     ) -> dict:
         system = (
             messages[0].get("content") if messages[0].get("role") == "system" else None
@@ -321,7 +326,7 @@ class AnthropicChat(LocalCompletionsAPI):
         gen_kwargs.pop("do_sample", False)
         max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
         temperature = gen_kwargs.pop("temperature", 0)
-        stop = gen_kwargs.pop("until", ["\n\nHuman:"])
+        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
         if not isinstance(stop, list):
             stop = [stop]
         out = {
...
@@ -80,6 +80,7 @@ class TemplateAPI(TemplateLM):
         revision: Optional[str] = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
+        eos_string: str = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -124,6 +125,7 @@ class TemplateAPI(TemplateLM):
         self.tokenized_requests = tokenized_requests
         self.max_retries = int(max_retries)
         self.verify_certificate = verify_certificate
+        self._eos_string = eos_string
         eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
         if self.tokenizer_backend is None:
@@ -176,6 +178,7 @@ class TemplateAPI(TemplateLM):
         generate: bool = True,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos: str = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
@@ -268,6 +271,21 @@ class TemplateAPI(TemplateLM):
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.eot_token

+    @cached_property
+    def eos_string(self) -> Optional[str]:
+        if self._eos_string:
+            return self._eos_string
+        elif self.tokenizer is not None:
+            if self.tokenizer_backend == "huggingface":
+                return self.tokenizer.eos_token
+            elif self.tokenizer_backend == "tiktoken":
+                return self.tokenizer.decode([self.tokenizer.eot_token])
+        else:
+            eval_logger.warning(
+                "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
+            )
+            return None
+
     @cached_property
     def prefix_token_id(self) -> Optional[int]:
         if self.tokenizer is None:
@@ -343,6 +361,7 @@ class TemplateAPI(TemplateLM):
                     generate=generate,
                     gen_kwargs=gen_kwargs,
                     seed=self._seed,
+                    eos=self.eos_string,
                     **kwargs,
                 ),
                 headers=self.header,
...
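If neither a HuggingFace nor a tiktoken tokenizer is configured, `eos_string` resolves to None and the warning above suggests setting it manually through `model_args`. One plausible way to do that from Python, via lm-eval's standard `simple_evaluate` entry point; the invocation shape, model name, URL, and task below are assumptions for illustration, not taken from this diff:

```python
import lm_eval

# `eos_string` is the new TemplateAPI constructor argument added in this diff;
# model, base_url, and the task name are placeholder values.
results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args="model=my-model,base_url=http://localhost:8000/v1/completions,eos_string=</s>",
    tasks=["lambada_openai"],
)
```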
@@ -14,6 +14,7 @@ from lm_eval.models.huggingface import HFLM
 from lm_eval.models.utils import (
     Collator,
     flatten_image_list,
+    handle_stop_sequences,
     pad_and_concat,
     replace_placeholders,
     stop_sequences_criteria,
@@ -629,7 +630,7 @@ class HFMultimodalLM(HFLM):
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
@@ -646,27 +647,14 @@ class HFMultimodalLM(HFLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
...
@@ -33,6 +33,7 @@ from lm_eval.models.utils import (
     clear_torch_cache,
     configure_pad_token,
     get_dtype,
+    handle_stop_sequences,
     pad_and_concat,
     stop_sequences_criteria,
 )
@@ -1255,33 +1256,21 @@ class HFLM(TemplateLM):
             group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
...
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
+from lm_eval.models.utils import handle_stop_sequences
 from lm_eval.utils import eval_logger
@@ -25,6 +26,7 @@ class LocalCompletionsAPI(TemplateAPI):
         generate=False,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos=None,
         **kwargs,
     ) -> dict:
         if generate:
@@ -34,7 +36,7 @@ class LocalCompletionsAPI(TemplateAPI):
             else:
                 max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
             return {
                 "prompt": messages,
                 "model": self.model,
@@ -124,6 +126,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
+        eos=None,
         **kwargs,
     ) -> dict:
         assert (
@@ -135,7 +138,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
             else:
                 max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
             if not isinstance(stop, (list, tuple)):
                 stop = [stop]
             return {
@@ -252,6 +255,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
+        eos="<|endoftext|>",
         **kwargs,
     ) -> dict:
         assert (
@@ -263,7 +267,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
             else:
                 max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
             if not isinstance(stop, (list, tuple)):
                 stop = [stop]
             output = {
...
@@ -709,3 +709,21 @@ def flatten_image_list(images: List[List]):
     :return: a list of PIL images, via concatenating all the sub-lists in order.
     """
     return [image for image_list in images for image in image_list]
+
+
+def handle_stop_sequences(
+    until: Union[str, List[str], None], eos: Optional[str]
+) -> List[str]:
+    """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
+    if isinstance(until, str):
+        until = [until]
+    elif until is None:
+        until = []
+    elif not isinstance(until, list):
+        raise ValueError(
+            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+        )
+
+    if eos is not None and eos not in until:
+        until.append(eos)
+    return until
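One small behavioral refinement over the inlined code this replaces: the EOS string is appended only when it is not already present, so a task that already lists the EOS token in `until` no longer ends up with a duplicate stop sequence. An illustrative check (arbitrary values, not part of the repo's test suite):

```python
from lm_eval.models.utils import handle_stop_sequences

# The old inline code appended the EOS string unconditionally; the helper
# skips it when it is already among the stop sequences.
stops = handle_stop_sequences(["<|endoftext|>", "\n\n"], eos="<|endoftext|>")
print(stops)  # ['<|endoftext|>', '\n\n']
```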
@@ -10,7 +10,12 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.models.utils import Collator, configure_pad_token, undistribute
+from lm_eval.models.utils import (
+    Collator,
+    configure_pad_token,
+    handle_stop_sequences,
+    undistribute,
+)
 from lm_eval.utils import (
     eval_logger,
     get_rolling_token_windows,
@@ -346,6 +351,7 @@ class VLLM(TemplateLM):
             desc="Running generate_until requests",
         )
         # for each different set of kwargs, we execute all requests, by batch.
+        eos = self.tokenizer.decode(self.eot_token_id)
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
@@ -353,27 +359,14 @@ class VLLM(TemplateLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tokenizer.decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
...
@@ -7,7 +7,12 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
-from lm_eval.models.utils import Collator, replace_placeholders, undistribute
+from lm_eval.models.utils import (
+    Collator,
+    handle_stop_sequences,
+    replace_placeholders,
+    undistribute,
+)
 from lm_eval.models.vllm_causallms import VLLM
 from lm_eval.utils import eval_logger
@@ -225,7 +230,7 @@ class VLLM_VLM(VLLM):
             group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        eos = self.tokenizer.decode(self.eot_token_id)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
@@ -241,27 +246,14 @@ class VLLM_VLM(VLLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tokenizer.decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
...
@@ -63,13 +63,13 @@ def test_create_payload_loglikelihood(api):
         (
             ["Hello, how are"],
             True,
-            {"max_gen_toks": 100, "temperature": 0.7},
+            {"max_gen_toks": 100, "temperature": 0.7, "until": ["hi"]},
             {
                 "prompt": "Hello, how are",
                 "model": "gpt-3.5-turbo",
                 "max_tokens": 100,
                 "temperature": 0.7,
-                "stop": ["<|endoftext|>"],
+                "stop": ["hi"],
                 "seed": 1234,
             },
         ),
@@ -82,7 +82,7 @@ def test_create_payload_loglikelihood(api):
                 "model": "gpt-3.5-turbo",
                 "max_tokens": 256,
                 "temperature": 0,
-                "stop": ["<|endoftext|>"],
+                "stop": [],
                 "seed": 1234,
             },
         ),
...
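The updated expectations follow directly from the helper: when `until` is supplied it is passed through as the stop list, and when neither `until` nor an EOS string is available the stop list is empty (the `api` fixture here presumably has no tokenizer configured, so `eos_string` resolves to None). Roughly:

```python
from lm_eval.models.utils import handle_stop_sequences

print(handle_stop_sequences(["hi"], eos=None))  # ['hi']  -> "stop": ["hi"]
print(handle_stop_sequences(None, eos=None))    # []      -> "stop": []
```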