Commit 2106fbeb authored by Baber

Merge branch 'main' into mathvista

# Conflicts:
#	lm_eval/models/openai_completions.py
parents 4354fe46 703fbffd
@@ -4,7 +4,7 @@ from typing import List
 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.registry import get_filter
-from . import extraction, selection, transformation
+from . import custom, extraction, selection, transformation
 def build_filter_ensemble(
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("custom")
class CustomFilter(Filter):
    """
    Custom filter that applies a custom, user-defined function to the model responses.
    """

    def __init__(self, **kwargs) -> None:
        self.filter_fn = kwargs.pop("filter_fn")
        super().__init__(**kwargs)

    def apply(self, resps, docs):
        return self.filter_fn(resps, docs)
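Below is a hedged usage sketch of this new filter (the normalization function and the sample data are made up for illustration, not part of the harness): filter_fn receives the per-document response lists plus the matching docs, and whatever it returns becomes the filtered responses.

from lm_eval.filters.custom import CustomFilter

def strip_and_lower(resps, docs):
    # normalize every response of every document
    return [[r.strip().lower() for r in doc_resps] for doc_resps in resps]

custom = CustomFilter(filter_fn=strip_and_lower)
filtered = custom.apply(
    resps=[["  The Answer IS 42  "]],
    docs=[{"question": "What is the answer?"}],
)
# filtered == [["the answer is 42"]]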
@@ -8,12 +8,17 @@ from lm_eval.api.registry import register_filter
 @register_filter("regex")
 class RegexFilter(Filter):
-    """ """
+    """A filter that extracts values from text using regex pattern matching.
+    This filter applies a regex pattern to each model response and extracts matched values.
+    If no match is found, returns a fallback value. Useful for extracting structured data
+    (like numbers) from unstructured model outputs.
+    """
     def __init__(
         self,
         regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
-        group_select=0,
+        group_select: int = 0,
         fallback: str = "[invalid]",
     ) -> None:
         """
@@ -25,7 +30,7 @@ class RegexFilter(Filter):
         self.group_select = group_select
         self.fallback = fallback
-    def apply(self, resps, docs):
+    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
         # here, we assume we have a list, in which each element is
         # a list of model responses for some particular input/target pair.
         # so we process each of these (same input/target response sets)
@@ -37,28 +42,27 @@ class RegexFilter(Filter):
                 if match:
                     match = match[self.group_select]
                     if isinstance(match, tuple):
-                        match = [m for m in match if m][0]
+                        match = [m for m in match if m]
+                        if match:
+                            match = match[0]
+                        else:
+                            match = self.fallback
                     match = match.strip()
                 else:
                     match = self.fallback
                 filtered.append(match)
             return filtered
-        # print(resps)
         filtered_resps = list(map(lambda x: filter_set(x), resps))
-        # print(filtered_resps)
         return filtered_resps
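The tuple handling fixed above can be traced with a small standalone snippet (the pattern and strings are invented; this mirrors the filter's logic rather than calling it): with an alternation pattern, re.findall yields tuples of groups, and an all-empty result now falls back instead of raising IndexError.

import re

pattern = re.compile(r"The answer is (\-?[0-9\.\,]+)|answer: (\-?[0-9\.\,]+)")
fallback = "[invalid]"

for resp in ["answer: 7", "no number here"]:
    match = pattern.findall(resp)
    if match:
        match = match[0]                    # group_select = 0
        if isinstance(match, tuple):
            match = [m for m in match if m]  # drop empty groups
            match = match[0] if match else fallback
        match = match.strip()
    else:
        match = fallback
    print(resp, "->", match)                # "7", then "[invalid]"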
@register_filter("remove_whitespace") @register_filter("remove_whitespace")
class WhitespaceFilter(Filter): class WhitespaceFilter(Filter):
""" """ """Filters out leading whitespace from responses."""
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
for resp in inst: for resp in inst:
...@@ -103,7 +107,7 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -103,7 +107,7 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs): def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
...@@ -162,7 +166,7 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -162,7 +166,7 @@ class MultiChoiceRegexFilter(RegexFilter):
fallback_regex = re.compile("|".join(fallback_regexes)) fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile( without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})" rf":[\s]*({without_paren_fallback_regex})"
) )
filtered = [] filtered = []
......
@@ -15,10 +15,9 @@ logger = logging.getLogger(__name__)
 def get_wandb_printer() -> Literal["Printer"]:
     """Returns a wandb printer instance for pretty stdout."""
-    from wandb.sdk.lib.printer import get_printer
-    from wandb.sdk.wandb_settings import Settings
-    printer = get_printer(Settings()._jupyter)
+    from wandb.sdk.lib.printer import new_printer
+    printer = new_printer()
     return printer
@@ -49,6 +48,9 @@ class WandbLogger:
         self.wandb_args: Dict[str, Any] = kwargs
+        # pop the step key from the args to save for all logging calls
+        self.step = self.wandb_args.pop("step", None)
         # initialize a W&B run
         if wandb.run is None:
             self.run = wandb.init(**self.wandb_args)
@@ -153,11 +155,11 @@ class WandbLogger:
         # log the complete eval result to W&B Table
         table = make_table(["Tasks"] + columns, "results")
-        self.run.log({"evaluation/eval_results": table})
+        self.run.log({"evaluation/eval_results": table}, step=self.step)
         if "groups" in self.results.keys():
             table = make_table(["Groups"] + columns, "groups")
-            self.run.log({"evaluation/group_eval_results": table})
+            self.run.log({"evaluation/group_eval_results": table}, step=self.step)
     def _log_results_as_artifact(self) -> None:
         """Log results as JSON artifact to W&B."""
@@ -175,13 +177,13 @@ class WandbLogger:
         """Log evaluation results to W&B."""
         # Log configs to wandb
         configs = self._get_config()
-        self.run.config.update(configs)
+        self.run.config.update(configs, allow_val_change=self.step is not None)
         wandb_summary, self.wandb_results = self._sanitize_results_dict()
         # update wandb.run.summary with items that were removed
         self.run.summary.update(wandb_summary)
         # Log the evaluation metrics to wandb
-        self.run.log(self.wandb_results)
+        self.run.log(self.wandb_results, step=self.step)
         # Log the evaluation metrics as W&B Table
         self._log_results_as_table()
         # Log the results dict as json to W&B Artifacts
@@ -330,7 +332,7 @@ class WandbLogger:
         # log the samples as a W&B Table
         df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
-        self.run.log({f"{task_name}_eval_results": df})
+        self.run.log({f"{task_name}_eval_results": df}, step=self.step)
         # log the samples as a json file as W&B Artifact
         self._log_samples_as_artifact(eval_preds, task_name)
@@ -349,4 +351,4 @@ class WandbLogger:
             # log the samples as a json file as W&B Artifact
             self._log_samples_as_artifact(eval_preds, task_name)
-        self.run.log({f"{group}_eval_results": grouped_df})
+        self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
@@ -5,11 +5,13 @@ from . import (
     gguf,
     hf_vlms,
     huggingface,
+    ibm_watsonx_ai,
     mamba_lm,
     nemo_lm,
     neuralmagic,
     neuron_optimum,
     openai_completions,
+    optimum_ipex,
     optimum_lm,
     textsynth,
     vllm_causallms,
@@ -8,7 +8,7 @@ from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.openai_completions import LocalCompletionsAPI
-from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
 eval_logger = utils.eval_logger
@@ -45,8 +45,8 @@ def anthropic_completion(
     try:
         import anthropic
-    except ModuleNotFoundError:
-        raise Exception(
+    except ModuleNotFoundError as exception:
+        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
 please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
         )
@@ -108,8 +108,8 @@ def anthropic_chat(
     try:
         import anthropic
-    except ModuleNotFoundError:
-        raise Exception(
+    except ModuleNotFoundError as exception:
+        raise type(exception)(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
 please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
         )
@@ -168,8 +168,8 @@ class AnthropicLM(LM):
         try:
             import anthropic
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
 please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
             )
@@ -217,8 +217,8 @@ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
         try:
             import anthropic
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
 please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
             )
@@ -311,7 +311,12 @@ class AnthropicChat(LocalCompletionsAPI):
     }
     def _create_payload(
-        self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
+        self,
+        messages: List[Dict],
+        generate=True,
+        gen_kwargs: dict = None,
+        eos="\n\nHuman:",
+        **kwargs,
     ) -> dict:
         system = (
             messages[0].get("content") if messages[0].get("role") == "system" else None
@@ -321,7 +326,7 @@ class AnthropicChat(LocalCompletionsAPI):
         gen_kwargs.pop("do_sample", False)
         max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
         temperature = gen_kwargs.pop("temperature", 0)
-        stop = gen_kwargs.pop("until", ["\n\nHuman:"])
+        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
         if not isinstance(stop, list):
             stop = [stop]
         out = {
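handle_stop_sequences itself lives in lm_eval/models/utils.py and is not shown in this diff. The sketch below is only an assumption about the behavior these call sites rely on (normalize until to a list and make sure the EOS/stop string is included); it is not the harness implementation.

from typing import List, Optional, Union

def handle_stop_sequences_sketch(
    until: Union[None, str, List[str]], eos: Optional[str]
) -> List[str]:
    # normalize to a list
    if until is None:
        until = []
    elif isinstance(until, str):
        until = [until]
    # append the EOS/stop string if it is known and not already present
    if eos is not None and eos not in until:
        until.append(eos)
    return until

# handle_stop_sequences_sketch(None, "\n\nHuman:")     -> ["\n\nHuman:"]
# handle_stop_sequences_sketch("Q:", "<|endoftext|>")  -> ["Q:", "<|endoftext|>"]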
@@ -21,7 +21,7 @@ from typing import (
 try:
     import requests
-    from aiohttp import ClientSession, TCPConnector
+    from aiohttp import ClientSession, ClientTimeout, TCPConnector
     from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
     from tqdm import tqdm
     from tqdm.asyncio import tqdm_asyncio
@@ -58,11 +58,11 @@ class TemplateAPI(TemplateLM):
         pretrained: str = None,  # `model` takes precedence over `pretrained` when passed.
         base_url: str = None,
         tokenizer: Optional[str] = None,
-        # Logliklehood tasks require a tokenizer to calculate context lengths,
+        # Loglikelihood tasks require a tokenizer to calculate context lengths,
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
         tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", None]
+            Literal["tiktoken", "huggingface", "None", "none"]
         ] = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
@@ -79,6 +79,10 @@ class TemplateAPI(TemplateLM):
         trust_remote_code: bool = False,
         revision: Optional[str] = "main",
         use_fast_tokenizer: bool = True,
+        verify_certificate: bool = True,
+        eos_string: str = None,
+        # timeout in seconds
+        timeout: int = 300,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -115,11 +119,16 @@ class TemplateAPI(TemplateLM):
                 "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
             )
         self._concurrent = int(num_concurrent)
-        self.tokenizer_backend = tokenizer_backend
+        self.tokenizer_backend = (
+            None if tokenizer_backend in ("None", "none") else tokenizer_backend
+        )
         self.add_bos_token = add_bos_token
         self.custom_prefix_token_id = custom_prefix_token_id
         self.tokenized_requests = tokenized_requests
         self.max_retries = int(max_retries)
+        self.verify_certificate = verify_certificate
+        self._eos_string = eos_string
+        self.timeout = int(timeout)
         eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
         if self.tokenizer_backend is None:
@@ -144,7 +153,7 @@ class TemplateAPI(TemplateLM):
                 self.tokenizer = tiktoken.encoding_for_model(self.model)
             except ModuleNotFoundError as e:
-                raise Exception(
+                raise ModuleNotFoundError(
                     "Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
                     "Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
                 ) from e
@@ -172,6 +181,7 @@ class TemplateAPI(TemplateLM):
         generate: bool = True,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos: str = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
@@ -194,7 +204,7 @@ class TemplateAPI(TemplateLM):
         if not self.tokenized_requests:
             # if messages are tokenized:
             if isinstance(messages[0][0], int):
-                # assuming decoding is lossless. However, this is only for logliklehood requests
+                # assuming decoding is lossless. However, this is only for loglikelihood requests
                # as we need to compute the context length. For generations, we don't need to tokenize.
                messages = self.decode_batch(messages)
         if self._batch_size <= 1:
@@ -243,12 +253,15 @@ class TemplateAPI(TemplateLM):
             return ""
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]]
+        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
     ) -> Union[str, JsonChatStr]:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
-                chat_history, tokenize=False, add_generation_prompt=True
+                chat_history,
+                tokenize=False,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=not add_generation_prompt,
             )
         else:
             # bit of a hack. We'll load back before sending to the API
@@ -264,6 +277,21 @@ class TemplateAPI(TemplateLM):
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.eot_token
+    @cached_property
+    def eos_string(self) -> Optional[str]:
+        if self._eos_string:
+            return self._eos_string
+        elif self.tokenizer is not None:
+            if self.tokenizer_backend == "huggingface":
+                return self.tokenizer.eos_token
+            elif self.tokenizer_backend == "tiktoken":
+                return self.tokenizer.decode([self.tokenizer.eot_token])
+        else:
+            eval_logger.warning(
+                "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
+            )
+            return None
     @cached_property
     def prefix_token_id(self) -> Optional[int]:
         if self.tokenizer is None:
@@ -339,9 +367,11 @@ class TemplateAPI(TemplateLM):
                     generate=generate,
                     gen_kwargs=gen_kwargs,
                     seed=self._seed,
+                    eos=self.eos_string,
                     **kwargs,
                 ),
                 headers=self.header,
+                verify=self.verify_certificate,
             )
             if not response.ok:
                 eval_logger.warning(
@@ -412,7 +442,7 @@ class TemplateAPI(TemplateLM):
             )
             return None
-    def batch_logliklehood_requests(
+    def batch_loglikelihood_requests(
         self, chunks: Iterable[List[LogLikelihoodInputs]]
     ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
         inputs = []
@@ -421,9 +451,13 @@ class TemplateAPI(TemplateLM):
         for chunk in chunks:
             for cache_key, context_enc, continuation_enc in chunk:
                 # max_length - 1 as we always have 1 token for generation
-                inp = (context_enc + continuation_enc)[-(self.max_length) :]
+                inp = (context_enc + continuation_enc)[-self.max_length :]
+                if len(inp) < len(context_enc + continuation_enc):
+                    eval_logger.warning(
+                        f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
+                    )
                 ctxlen = len(context_enc) - max(
-                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
+                    0, len(context_enc) + len(continuation_enc) - self.max_length
                 )
                 inputs.append(inp)
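A small worked example of the left-truncation arithmetic above, with invented token lists: when context plus continuation exceed max_length, tokens are dropped from the left and ctxlen shrinks by the overflow.

max_length = 8
context_enc = [1, 2, 3, 4, 5, 6]      # 6 context tokens
continuation_enc = [7, 8, 9]          # 3 continuation tokens

inp = (context_enc + continuation_enc)[-max_length:]                      # keeps the last 8 tokens
overflow = max(0, len(context_enc) + len(continuation_enc) - max_length)  # 1 token too many
ctxlen = len(context_enc) - overflow                                      # 5 context tokens remain
assert inp == [2, 3, 4, 5, 6, 7, 8, 9] and ctxlen == 5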
@@ -442,7 +476,9 @@ class TemplateAPI(TemplateLM):
     ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
         ctxlens = ctxlens if ctxlens else [None] * len(requests)
         conn = TCPConnector(limit=self._concurrent)
-        async with ClientSession(connector=conn) as session:
+        async with ClientSession(
+            connector=conn, timeout=ClientTimeout(total=self.timeout)
+        ) as session:
             retry_: Callable[..., Awaitable[Any]] = retry(
                 stop=stop_after_attempt(self.max_retries),
                 wait=wait_exponential(multiplier=0.5, min=1, max=10),
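For reference, a minimal self-contained sketch of the same aiohttp setup (the URL and payload are placeholders): the connector caps concurrent connections and ClientTimeout bounds each request at the new timeout argument, 300 seconds by default.

import asyncio
from aiohttp import ClientSession, ClientTimeout, TCPConnector

async def post_json(url: str, payload: dict, concurrent: int = 8, timeout_s: int = 300):
    conn = TCPConnector(limit=concurrent)  # at most `concurrent` open connections
    async with ClientSession(
        connector=conn, timeout=ClientTimeout(total=timeout_s)
    ) as session:
        async with session.post(url, json=payload) as resp:
            resp.raise_for_status()
            return await resp.json()

# asyncio.run(post_json("http://localhost:8000/v1/completions", {"prompt": "2+2="}))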
@@ -497,7 +533,7 @@ class TemplateAPI(TemplateLM):
         if self._concurrent <= 1:
             pbar = tqdm(desc="Requesting API", total=len(requests))
             for chunk in chunked:
-                inputs, ctxlens, cache_keys = self.batch_logliklehood_requests([chunk])
+                inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests([chunk])
                 outputs = retry(
                     stop=stop_after_attempt(self.max_retries),
@@ -521,7 +557,7 @@ class TemplateAPI(TemplateLM):
                 )
                 pbar.update(1)
         else:
-            inputs, ctxlens, cache_keys = self.batch_logliklehood_requests(chunked)
+            inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests(chunked)
             res = itertools.chain.from_iterable(
                 asyncio.run(
                     self.get_batched_requests(
@@ -565,6 +601,24 @@ class TemplateAPI(TemplateLM):
             pbar = tqdm(desc="Requesting API", total=len(requests))
             for chunk in chunked:
                 contexts, all_gen_kwargs, encodings_list = zip(*chunk)
+                if self.tokenized_requests:
+                    max_gen_toks = all_gen_kwargs[0].get(
+                        "max_gen_toks", self._max_gen_toks
+                    )
+                    max_context_len = self.max_length - max_gen_toks
+                    encodings_list = [x[-max_context_len:] for x in encodings_list]
+                    if any(
+                        len(x) + max_gen_toks > self.max_length for x in encodings_list
+                    ):
+                        eval_logger.warning(
+                            f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
+                        )
+                else:
+                    eval_logger.info(
+                        "Tokenized requests are disabled. Context + generation length is not checked."
+                    )
                 req = encodings_list if self.tokenized_requests else contexts
                 outputs = retry(
                     stop=stop_after_attempt(self.max_retries),
@@ -596,6 +650,24 @@ class TemplateAPI(TemplateLM):
         else:
             for chunk in chunked:
                 contexts, all_gen_kwargs, encodings_list = zip(*chunk)
+                if self.tokenized_requests:
+                    max_gen_toks = all_gen_kwargs[0].get(
+                        "max_gen_toks", self._max_gen_toks
+                    )
+                    max_context_len = self.max_length - max_gen_toks
+                    encodings_list = [x[-max_context_len:] for x in encodings_list]
+                    if any(
+                        len(x) + max_gen_toks > self.max_length for x in encodings_list
+                    ):
+                        eval_logger.warning(
+                            f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
+                        )
+                else:
+                    eval_logger.info(
+                        "Tokenized requests are disabled. Context + generation length is not checked."
+                    )
                 req = encodings_list if self.tokenized_requests else contexts
                 results = itertools.chain.from_iterable(
                     asyncio.run(
@@ -68,7 +68,9 @@ class GGUFLM(LM):
                 logger.error(f"RequestException: {e}")
                 time.sleep(delay)  # wait before retrying
         else:
-            raise Exception(f"Failed to get a valid response after {retries} retries.")
+            raise RuntimeError(
+                f"Failed to get a valid response after {retries} retries."
+            )
     def loglikelihood(self, requests, disable_tqdm: bool = False):
         if not requests:
@@ -14,6 +14,7 @@ from lm_eval.models.huggingface import HFLM
 from lm_eval.models.utils import (
     Collator,
     flatten_image_list,
+    handle_stop_sequences,
     pad_and_concat,
     replace_placeholders,
     stop_sequences_criteria,
@@ -215,7 +216,9 @@ class HFMultimodalLM(HFLM):
         return context_enc, continuation_enc, image_enc
-    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+    def apply_chat_template(
+        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
         self.chat_applied = True
         if not self.interleave:
             for content in chat_history:
@@ -265,7 +268,9 @@ class HFMultimodalLM(HFLM):
                 )
         return self.processor.apply_chat_template(
-            chat_history, add_generation_prompt=True
+            chat_history,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=not add_generation_prompt,
         )
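A hedged illustration of the two chat-template modes wired up here, using plain transformers (the checkpoint name is a placeholder, and continue_final_message needs a recent transformers release): add_generation_prompt=True opens a fresh assistant turn, while the new add_generation_prompt=False path asks the template to continue a partially written assistant message.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# add_generation_prompt=True: the history ends with a user turn and the template opens a new assistant turn
chat = [{"role": "user", "content": "Name a prime number."}]
prompt_new_turn = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# add_generation_prompt=False + continue_final_message=True: the model resumes a partial assistant message
chat_prefill = chat + [{"role": "assistant", "content": "A small prime number is"}]
prompt_continue = tok.apply_chat_template(
    chat_prefill, tokenize=False, add_generation_prompt=False, continue_final_message=True
)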
     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
@@ -661,7 +666,7 @@ class HFMultimodalLM(HFLM):
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
@@ -678,27 +683,14 @@ class HFMultimodalLM(HFLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
This diff is collapsed.
import copy
import json
import os
from functools import lru_cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import JsonChatStr
from lm_eval.utils import eval_logger, simple_parse_args_string
class LogLikelihoodResult(NamedTuple):
log_likelihood: float
is_greedy: bool
def _verify_credentials(creds: Any) -> None:
"""
Verifies that all required keys are present in the credentials dictionary.
Args:
creds (Any): A dictionary containing the credentials.
Raises:
ValueError: If any of the necessary credentials are missing, with guidance on which environment variables need to be set.
"""
required_keys = ["apikey", "url", "project_id"]
env_var_mapping = {
"apikey": "WATSONX_API_KEY",
"url": "WATSONX_URL",
"project_id": "WATSONX_PROJECT_ID",
}
missing_keys = [key for key in required_keys if key not in creds or not creds[key]]
if missing_keys:
missing_env_vars = [env_var_mapping[key] for key in missing_keys]
raise ValueError(
f"Missing required credentials: {', '.join(missing_keys)}. Please set the following environment variables: {', '.join(missing_env_vars)}"
)
@lru_cache(maxsize=None)
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environment variables.
Returns:
Dict[str, str]: A dictionary containing the credentials necessary for authentication, including
keys such as `apikey`, `url`, and `project_id`.
Raises:
ValueError: If any of the necessary credentials are missing or the credentials format is invalid.
"""
credentials = {
"apikey": os.getenv("WATSONX_API_KEY", None),
"url": os.getenv("WATSONX_URL", None),
"project_id": os.getenv("WATSONX_PROJECT_ID", None),
}
_verify_credentials(credentials)
return credentials
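A hedged sketch of the environment the credential helper above expects (all values are placeholders):

import os

os.environ["WATSONX_API_KEY"] = "..."                            # placeholder
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # placeholder region URL
os.environ["WATSONX_PROJECT_ID"] = "..."                         # placeholder

creds = get_watsonx_credentials()  # the helper defined above
# creds == {"apikey": "...", "url": "https://us-south.ml.cloud.ibm.com", "project_id": "..."}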
@register_model("watsonx_llm")
class WatsonxLLM(LM):
"""
Implementation of the LM model interface for evaluating Watsonx models with the lm_eval framework.
See https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md for reference.
"""
@classmethod
def create_from_arg_string(
cls: Type["WatsonxLLM"],
arg_string: str,
additional_config: Optional[Dict] = None,
) -> "WatsonxLLM":
"""
Allow the user to specify model parameters (TextGenerationParameters) in CLI arguments.
"""
try:
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
except ImportError:
raise ImportError(
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
)
args = simple_parse_args_string(arg_string)
args.update(additional_config)
model_id = args.pop("model_id", None)
if model_id is None:
raise ValueError("'model_id' is required, please pass it in 'model_args'")
if not args.get("do_sample", None):
args["temperature"] = None
args["top_p"] = None
args["top_k"] = None
args["seed"] = None
generate_params = {
GenParams.DECODING_METHOD: (
"greedy" if not args.get("do_sample", None) else "sample"
),
GenParams.LENGTH_PENALTY: args.get("length_penalty", None),
GenParams.TEMPERATURE: args.get("temperature", None),
GenParams.TOP_P: args.get("top_p", None),
GenParams.TOP_K: args.get("top_k", None),
GenParams.RANDOM_SEED: args.get("seed", None),
GenParams.REPETITION_PENALTY: args.get("repetition_penalty", None),
GenParams.MIN_NEW_TOKENS: args.get("min_new_tokens", None),
GenParams.MAX_NEW_TOKENS: args.get("max_new_tokens", 256),
GenParams.STOP_SEQUENCES: args.get("stop_sequences", None),
GenParams.TIME_LIMIT: args.get("time_limit", None),
GenParams.TRUNCATE_INPUT_TOKENS: args.get("truncate_input_tokens", None),
GenParams.RETURN_OPTIONS: {
"generated_tokens": True,
"input_tokens": True,
"token_logprobs": True,
"token_ranks": True,
},
}
generate_params = {k: v for k, v in generate_params.items() if v is not None}
return cls(
watsonx_credentials=get_watsonx_credentials(),
model_id=model_id,
generate_params=generate_params,
)
def __init__(
self,
watsonx_credentials: Dict,
model_id,
generate_params: Optional[Dict[Any, Any]] = None,
) -> None:
try:
from ibm_watsonx_ai import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
except ImportError:
raise ImportError(
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
)
super().__init__()
client = APIClient(watsonx_credentials)
project_id = watsonx_credentials.get("project_id", None)
deployment_id = watsonx_credentials.get("deployment_id", None)
client.set.default_project(project_id)
self.generate_params = generate_params
self.model = ModelInference(
model_id=model_id,
deployment_id=deployment_id,
api_client=client,
project_id=project_id,
)
self._model_id = model_id
@staticmethod
def _has_stop_token(response_tokens: List[str], context_tokens: List[str]) -> bool:
"""
Determines whether a stop token has been generated in the `response_tokens` compared to the `context_tokens`.
If the tokens do not match as expected, the function raises a RuntimeError, indicating a possible
misalignment between the tokens generated by the tokenizer and the model.
Args:
response_tokens (List[str]): The List of tokens generated as a response by the model.
context_tokens (List[str]): The List of tokens representing the input context.
Returns:
bool: True if the `response_tokens` likely contain a stop token that terminates the sequence,
otherwise raises an exception.
Raises:
RuntimeError: If there is an unexpected mismatch between the `response_tokens` and the `context_tokens`.
"""
context_length = len(context_tokens)
if response_tokens[: context_length - 1] == context_tokens[:-1]:
return (
response_tokens[-1] != context_tokens[-1]
) # only last token differs, probably stop sequence (</s>)
raise RuntimeError(
f"There is an unexpected difference between tokenizer and model tokens:\n"
f"context_tokens={context_tokens}\n"
f"response_tokens={response_tokens[:context_length]}"
)
def _check_model_logprobs_support(self):
"""
Verifies if the model supports returning log probabilities for input tokens.
This function sends a prompt to the model and checks whether the model's response
includes log probabilities for the input tokens. If log probabilities are not present,
it raises a `RuntimeError`, indicating that the model is not supported.
Raises:
RuntimeError: If the model does not return log probabilities for input tokens.
"""
tokens = self.model.generate_text(
prompt=["The best ice cream flavor is:"],
params=self.generate_params,
raw_response=True,
)[0]["results"][0]
if all(token.get("logprob", None) is None for token in tokens["input_tokens"]):
raise RuntimeError(
f"Model {self._model_id} is not supported: does not return logprobs for input tokens"
)
def _get_log_likelihood(
self,
input_tokens: List[Dict[str, float]],
context_tokens: List[Dict[str, float]],
) -> LogLikelihoodResult:
"""
Calculates the log likelihood of the generated tokens compared to the context tokens.
Args:
input_tokens (List[Dict[str, float]]): A List of token dictionaries, each containing
token information like `text` and `logprob`.
context_tokens (List[Dict[str, float]]): A List of token dictionaries representing
the input context.
Returns:
LogLikelihoodResult: An object containing the calculated log likelihood and a boolean
flag indicating if the tokens were generated greedily.
"""
response_tokens = [token["text"] for token in input_tokens]
context_length = len(context_tokens)
if self._has_stop_token(response_tokens, context_tokens):
context_length -= 1
return LogLikelihoodResult(
log_likelihood=sum(
token.get("logprob", 0) for token in input_tokens[context_length:]
),
is_greedy=all(
token["rank"] == 1 for token in input_tokens[context_length:]
),
)
def generate_until(self, requests: List[Instance]) -> List[str]:
"""
Generates text responses for a List of requests, with progress tracking and caching.
Args:
requests (List[Instance]): A List of instances, each containing a text input to be processed.
Returns:
List[str]: A List of generated responses.
"""
requests = [request.args for request in requests]
results = []
for request in tqdm(
requests,
desc="Running generate_until function ...",
):
context, continuation = request
try:
if isinstance(context, JsonChatStr):
context = json.loads(context.prompt)
response = self.model.chat(context, self.generate_params)
response = response["choices"][0]["message"]["content"]
else:
response = self.model.generate_text(context, self.generate_params)
except Exception as exp:
eval_logger.error("Error while generating text.")
raise exp
results.append(response)
self.cache_hook.add_partial(
"generate_until", (context, continuation), response
)
return results
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
"""
Args:
requests: Each request contains Instance.args : Tuple[str, str] containing:
1. an input string to the LM and
2. a target string on which the loglikelihood of the LM producing this target,
conditioned on the input, will be returned.
Returns:
Tuple (loglikelihood, is_greedy) for each request according to the input order:
loglikelihood: probability of generating the target string conditioned on the input
is_greedy: True if and only if the target string would be generated by greedy sampling from the LM
"""
try:
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
except ImportError:
raise ImportError(
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
)
self._check_model_logprobs_support()
generate_params = copy.copy(self.generate_params)
generate_params[GenParams.MAX_NEW_TOKENS] = 1
requests = [request.args for request in requests]
results: List[LogLikelihoodResult] = []
# Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
for request in tqdm(
requests,
desc="Running loglikelihood function ...",
):
context, continuation = request
try:
tokenized_context = self.model.tokenize(
prompt=context, return_tokens=True
)["result"]["tokens"]
except Exception as exp:
eval_logger.error("Error while model tokenize.")
raise exp
input_prompt = context + continuation
try:
response = self.model.generate_text(
prompt=input_prompt, params=generate_params, raw_response=True
)
except Exception as exp:
eval_logger.error("Error while model generate text.")
raise exp
log_likelihood_response = self._get_log_likelihood(
response["results"][0]["input_tokens"], tokenized_context
)
results.append(log_likelihood_response)
self.cache_hook.add_partial(
"loglikelihood",
(context, continuation),
(
log_likelihood_response.log_likelihood,
log_likelihood_response.is_greedy,
),
)
return cast(List[Tuple[float, bool]], results)
def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
"""
Used to evaluate perplexity on a data distribution.
Args:
requests: Each request contains Instance.args : Tuple[str] containing an input string to the model whose
entire loglikelihood, conditioned on purely the EOT token, will be calculated.
Returns:
Tuple (loglikelihood,) for each request according to the input order:
loglikelihood: solely the probability of producing each piece of text given no starting input.
"""
try:
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
except ImportError:
raise ImportError(
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
)
self._check_model_logprobs_support()
generate_params = copy.deepcopy(self.generate_params)
generate_params[GenParams.MAX_NEW_TOKENS] = 1
requests = [request.args for request in requests]
results: List[LogLikelihoodResult] = []
# Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
for request in tqdm(
requests,
desc="Running loglikelihood_rolling function ...",
):
context, continuation = request
try:
response = self.model.generate_text(
prompt=context, params=generate_params, raw_response=True
)
except Exception as exp:
eval_logger.error("Error while model generate text.")
raise exp
log_likelihood_response = self._get_log_likelihood(
response["results"][0]["input_tokens"], []
)
results.append(log_likelihood_response)
self.cache_hook.add_partial(
"loglikelihood_rolling",
(context, continuation),
log_likelihood_response.log_likelihood,
)
return cast(List[Tuple[float, bool]], results)
@property
def tokenizer_name(self) -> str:
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]]
) -> List[Dict[str, str]]:
# A hack similar to the one in api_models to allow encoding for the cache
return JsonChatStr(json.dumps(chat_history))
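A hedged end-to-end sketch using the harness's documented Python API (the model_id, task, and limit are placeholders): with the WATSONX_* variables set, the registered watsonx_llm backend can be driven like any other model.

import lm_eval

results = lm_eval.simple_evaluate(
    model="watsonx_llm",
    model_args="model_id=ibm/granite-13b-instruct-v2,max_new_tokens=256",
    tasks=["hellaswag"],
    limit=10,  # small smoke test
)
print(results["results"]["hellaswag"])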
@@ -12,6 +12,8 @@ class MambaLMWrapper(HFLM):
     def __init__(
         self,
         pretrained="state-spaces/mamba-130m",
+        # To use the HF compatible variant
+        is_hf: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -52,7 +54,7 @@ class MambaLMWrapper(HFLM):
         if "backend" in kwargs:
             # mamba currently only supports causal models
             assert kwargs["backend"] == "causal"
+        self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         super().__init__(
             pretrained=pretrained,
             # set appropriate defaults for tokenizer, max length, etc
@@ -67,15 +69,18 @@ class MambaLMWrapper(HFLM):
         pretrained: str,
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
-        except ModuleNotFoundError:
-            raise Exception(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._config = load_config_hf(pretrained)
+        if self.is_hf:
+            super()._get_config(pretrained, **kwargs)
+        else:
+            try:
+                from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._config = load_config_hf(pretrained)
     def _create_model(
         self,
@@ -86,24 +91,32 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
         # Mamba does not support arbitrary HF from_pretrained() args
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
-        except ModuleNotFoundError:
-            raise Exception(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._model = MambaLMHeadModel.from_pretrained(
-            pretrained,
-            device=self._device,
-            dtype=torch.float16
-            if dtype == "auto"
-            else lm_eval.models.utils.get_dtype(dtype),
-        )
+        if self.is_hf:
+            super()._create_model(pretrained, dtype=dtype, **kwargs)
+        else:
+            try:
+                from mamba_ssm.models.mixer_seq_simple import (
+                    MambaLMHeadModel,  # noqa: F811
+                )
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._model = MambaLMHeadModel.from_pretrained(
+                pretrained,
+                device=self._device,
+                dtype=torch.float16
+                if dtype == "auto"
+                else lm_eval.models.utils.get_dtype(dtype),
+            )
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
-        for key in ("do_sample", "attention_mask"):
+        remove_arg = (
+            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
+        )
+        for key in remove_arg:
             if key in generation_kwargs:
                 generation_kwargs.pop(key)
@@ -116,11 +129,37 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
         #     self.tokenizer, stop, 1, context.shape[0]
         # )
-        return self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            # stopping_criteria=stopping_criteria,
-            # pad_token_id=self.tokenizer.pad_token_id,
-            # use_cache=True,
-            **generation_kwargs,
-        )
+        if not self.is_hf:
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                # stopping_criteria=stopping_criteria,
+                # pad_token_id=self.tokenizer.pad_token_id,
+                # use_cache=True,
+                **generation_kwargs,
+            )
+        else:
+            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
+                self.tokenizer,
+                stop,
+                context.shape[1],
+                context.shape[0],
+            )
+            generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
+            do_sample = generation_kwargs.get("do_sample", None)
+            # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+                generation_kwargs["do_sample"] = do_sample = False
+            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
+                generation_kwargs.pop("temperature")
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                **generation_kwargs,
+            )
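A hedged usage sketch of the new flag (checkpoint names are illustrative; mamba_ssm is the wrapper's registered model name, and is_hf is also inferred when the checkpoint name ends in "hf"):

import lm_eval

# native mamba_ssm checkpoint (is_hf stays False)
native = lm_eval.simple_evaluate(
    model="mamba_ssm",
    model_args="pretrained=state-spaces/mamba-130m",
    tasks=["lambada_openai"],
    limit=10,
)

# transformers-compatible checkpoint: name ends in "hf", or pass is_hf=True explicitly
hf_variant = lm_eval.simple_evaluate(
    model="mamba_ssm",
    model_args="pretrained=state-spaces/mamba-130m-hf,is_hf=True",
    tasks=["lambada_openai"],
    limit=10,
)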
@@ -39,8 +39,8 @@ def _patch_pretrained_cfg(
 ):
     try:
         import omegaconf
-    except ModuleNotFoundError:
-        raise Exception(
+    except ModuleNotFoundError as exception:
+        raise type(exception)(
            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
            "or installing nemo following https://github.com/NVIDIA/NeMo.",
@@ -79,8 +79,8 @@ def load_model(
             MegatronGPTModel,
         )
         from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
-    except ModuleNotFoundError:
-        raise Exception(
+    except ModuleNotFoundError as exception:
+        raise type(exception)(
            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
            "or installing nemo following https://github.com/NVIDIA/NeMo.",
@@ -140,8 +140,8 @@ def load_model(
 def setup_distributed_environment(trainer):
     try:
         from nemo.utils.app_state import AppState
-    except ModuleNotFoundError:
-        raise Exception(
+    except ModuleNotFoundError as exception:
+        raise type(exception)(
            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
            "or installing nemo following https://github.com/NVIDIA/NeMo.",
@@ -187,15 +187,15 @@ class NeMoLM(LM):
         **kwargs,
     ):
         try:
+            from lightning.pytorch.trainer.trainer import Trainer
             from nemo.collections.nlp.modules.common.text_generation_utils import (
                 generate,
             )
             from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
-            from pytorch_lightning.trainer.trainer import Trainer
             self.generate = generate
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
                "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
                "or installing nemo following https://github.com/NVIDIA/NeMo.",
@@ -38,8 +38,8 @@ class SparseMLLM(HFLM):
     ) -> None:
         try:
             from sparseml.transformers import SparseAutoModelForCausalLM
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                 "Package `sparseml` is not installed. "
                 "Please install it via `pip install sparseml[transformers]`"
             )
@@ -88,8 +88,8 @@ class SparseMLLM(HFLM):
     def _get_config(self, pretrained: str, **kwargs) -> None:
         try:
             from sparseml.transformers import SparseAutoConfig
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                 "Package `sparseml` is not installed. "
                 "Please install it via `pip install sparseml[transformers]`"
             )
@@ -112,8 +112,8 @@ class SparseMLLM(HFLM):
     ) -> None:
         try:
             from sparseml.transformers import SparseAutoTokenizer
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                 "Package `sparseml` is not installed. "
                 "Please install it via `pip install sparseml[transformers]`"
             )
@@ -171,8 +171,8 @@ class DeepSparseLM(LM):
         try:
             import deepsparse
-        except ModuleNotFoundError:
-            raise Exception(
+        except ModuleNotFoundError as exception:
+            raise type(exception)(
                 "Package `deepsparse` is not installed. "
                 "Please install it via `pip install deepsparse[transformers]`"
             )
@@ -144,7 +144,7 @@ class NEURON_HF(TemplateLM):
         add_bos_token: Optional[bool] = False,
     ) -> None:
         if not NEURON_AVAILABLE:
-            raise Exception(
+            raise ImportError(
                 "Tried to load neuron model, but neuron is not installed ",
                 "please install neuron via pip install transformers-neuron ",
                 "also make sure you are running on an AWS inf2 instance",
...@@ -5,6 +5,7 @@ import itertools ...@@ -5,6 +5,7 @@ import itertools
import json import json
import os import os
from functools import cached_property from functools import cached_property
from operator import itemgetter
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
...@@ -14,6 +15,8 @@ from tqdm import tqdm ...@@ -14,6 +15,8 @@ from tqdm import tqdm
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI
from lm_eval.models.utils import handle_stop_sequences
from lm_eval.models.api_models import JsonChatStr, TemplateAPI from lm_eval.models.api_models import JsonChatStr, TemplateAPI
from lm_eval.models.utils import Collator from lm_eval.models.utils import Collator
from lm_eval.utils import eval_logger from lm_eval.utils import eval_logger
...@@ -37,6 +40,7 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -37,6 +40,7 @@ class LocalCompletionsAPI(TemplateAPI):
generate=False, generate=False,
gen_kwargs: Optional[dict] = None, gen_kwargs: Optional[dict] = None,
seed: int = 1234, seed: int = 1234,
eos=None,
**kwargs, **kwargs,
) -> dict: ) -> dict:
if generate: if generate:
...@@ -46,7 +50,7 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -46,7 +50,7 @@ class LocalCompletionsAPI(TemplateAPI):
else: else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks) max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0) temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"]) stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
return { return {
"prompt": messages, "prompt": messages,
"model": self.model, "model": self.model,
...@@ -78,7 +82,9 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -78,7 +82,9 @@ class LocalCompletionsAPI(TemplateAPI):
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
for out in outputs: for out in outputs:
for choice, ctxlen in zip(out["choices"], ctxlens): for choice, ctxlen in zip(
sorted(out["choices"], key=itemgetter("index")), ctxlens
):
assert ctxlen > 0, "Context length must be greater than 0" assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1]) logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1] tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
...@@ -97,8 +103,10 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -97,8 +103,10 @@ class LocalCompletionsAPI(TemplateAPI):
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
for out in outputs: for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]: for choices in out["choices"]:
res.append(choices["text"]) tmp[choices["index"]] = choices["text"]
res = res + tmp
return res return res
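The parser above now slots each completion into place by its reported "index" field instead of appending in arrival order, so outputs stay aligned with their prompts even if the server returns choices shuffled. A minimal illustrative sketch (the response dict below is invented, not a real API payload):

# Illustrative sketch: reordering choices by their "index" field.
out = {
    "choices": [
        {"index": 2, "text": "third"},
        {"index": 0, "text": "first"},
        {"index": 1, "text": "second"},
    ]
}

res = []
tmp = [None] * len(out["choices"])
for choice in out["choices"]:
    tmp[choice["index"]] = choice["text"]
res = res + tmp
print(res)  # ['first', 'second', 'third']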
@property @property
...@@ -136,15 +144,19 @@ class LocalChatCompletion(LocalCompletionsAPI): ...@@ -136,15 +144,19 @@ class LocalChatCompletion(LocalCompletionsAPI):
generate=False, generate=False,
gen_kwargs: dict = None, gen_kwargs: dict = None,
seed=1234, seed=1234,
eos=None,
**kwargs, **kwargs,
) -> dict: ) -> dict:
assert (
type(messages) is not str
), "chat-completions require the --apply_chat_template flag."
gen_kwargs.pop("do_sample", False) gen_kwargs.pop("do_sample", False)
if "max_tokens" in gen_kwargs: if "max_tokens" in gen_kwargs:
max_tokens = gen_kwargs.pop("max_tokens") max_tokens = gen_kwargs.pop("max_tokens")
else: else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks) max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0) temperature = gen_kwargs.pop("temperature", 0)
stop = gen_kwargs.pop("until", ["<|endoftext|>"]) stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
if not isinstance(stop, (list, tuple)): if not isinstance(stop, (list, tuple)):
stop = [stop] stop = [stop]
return { return {
...@@ -163,8 +175,10 @@ class LocalChatCompletion(LocalCompletionsAPI): ...@@ -163,8 +175,10 @@ class LocalChatCompletion(LocalCompletionsAPI):
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
for out in outputs: for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]: for choices in out["choices"]:
res.append(choices["message"]["content"]) tmp[choices["index"]] = choices["message"]["content"]
res = res + tmp
return res return res
def tok_encode( def tok_encode(
...@@ -229,6 +243,10 @@ class OpenAIChatCompletion(LocalChatCompletion): ...@@ -229,6 +243,10 @@ class OpenAIChatCompletion(LocalChatCompletion):
tokenized_requests=False, tokenized_requests=False,
**kwargs, **kwargs,
): ):
if "o1" in kwargs.get("model", ""):
eval_logger.warning(
"o1 models do not support `stop` and only support temperature=1"
)
super().__init__( super().__init__(
base_url=base_url, base_url=base_url,
tokenizer_backend=tokenizer_backend, tokenizer_backend=tokenizer_backend,
...@@ -251,6 +269,41 @@ class OpenAIChatCompletion(LocalChatCompletion): ...@@ -251,6 +269,41 @@ class OpenAIChatCompletion(LocalChatCompletion):
"Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation." "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
) )
def _create_payload(
self,
messages: List[Dict],
generate=False,
gen_kwargs: dict = None,
seed=1234,
eos="<|endoftext|>",
**kwargs,
) -> dict:
assert (
type(messages) is not str
), "chat-completions require the --apply_chat_template flag."
gen_kwargs.pop("do_sample", False)
if "max_tokens" in gen_kwargs:
max_tokens = gen_kwargs.pop("max_tokens")
else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
if not isinstance(stop, (list, tuple)):
stop = [stop]
output = {
"messages": messages,
"model": self.model,
"max_completion_tokens": max_tokens,
"temperature": temperature,
"stop": stop[:4],
"seed": seed,
**gen_kwargs,
}
if "o1" in self.model:
output.pop("stop")
output["temperature"] = 1
return output
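For reference, a rough standalone sketch of the o1 special-casing in the payload builder above; the helper name and the literal default values are placeholders for illustration, not part of the diff:

# Sketch of the o1 branch: drop `stop` and force temperature=1 for o1 models.
def build_payload(model: str, messages: list, max_tokens: int = 256) -> dict:
    output = {
        "messages": messages,
        "model": model,
        "max_completion_tokens": max_tokens,
        "temperature": 0,
        "stop": ["<|endoftext|>"][:4],
        "seed": 1234,
    }
    if "o1" in model:
        # o1 models reject `stop` and only accept temperature=1
        output.pop("stop")
        output["temperature"] = 1
    return output

build_payload("o1-mini", [{"role": "user", "content": "hi"}])
# -> payload without a "stop" key and with temperature set to 1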
@register_model("pixtral-api") @register_model("pixtral-api")
class PixtralAPI(LocalChatCompletion): class PixtralAPI(LocalChatCompletion):
......
from importlib.util import find_spec
from lm_eval import utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import get_dtype
eval_logger = utils.eval_logger
@register_model("ipex")
class IPEXLM(HFLM):
"""
Model wrapper using the HuggingFace Transformers + optimum-intel IPEX backend; runs on Intel CPUs and Intel GPUs.
"""
def __init__(
self,
**kwargs,
) -> None:
if "backend" in kwargs:
# currently only supports causal models
assert (
kwargs["backend"] == "causal"
), "Currently, only IPEXModelForCausalLM is supported."
super().__init__(
backend=kwargs.pop("backend", "causal"),
**kwargs,
)
def _create_model(
self,
pretrained: str,
revision="main",
dtype="auto",
trust_remote_code=False,
# arguments used for splitting a model across GPUs naively.
# only used if `parallelize=True`.
# (accelerate naive PP (device_map) options)
parallelize=False,
gpus=None,
max_memory_per_gpu=None,
max_cpu_memory=None,
offload_folder="./offload",
# PEFT, delta weights and quantization options
peft=None,
delta=None,
autogptq=False,
gptqmodel=False,
**kwargs,
) -> None:
if not find_spec("optimum"):
raise ModuleNotFoundError(
"package `optimum` is not installed. Please install it via `pip install optimum[ipex]`"
)
else:
from optimum.intel import IPEXModelForCausalLM
model_kwargs = kwargs if kwargs else {}
model_kwargs.update(
self._get_accelerate_args(
parallelize=parallelize,
device_map=kwargs.get("device_map", None),
max_memory_per_gpu=max_memory_per_gpu,
max_cpu_memory=max_cpu_memory,
offload_folder=offload_folder,
gpus=gpus,
)
)
self._model = IPEXModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code,
**model_kwargs,
)
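A rough usage sketch for the newly registered "ipex" backend (the pretrained model, task, and batch size are placeholders; assumes `optimum[ipex]` is installed):

# Evaluate a causal LM through the IPEX backend via the Python API.
import lm_eval

results = lm_eval.simple_evaluate(
    model="ipex",
    model_args="pretrained=gpt2,dtype=bfloat16",
    tasks=["lambada_openai"],
    batch_size=8,
)
print(results["results"])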
...@@ -50,7 +50,7 @@ class OptimumLM(HFLM): ...@@ -50,7 +50,7 @@ class OptimumLM(HFLM):
**kwargs, **kwargs,
) -> None: ) -> None:
if not find_spec("optimum"): if not find_spec("optimum"):
raise Exception( raise ModuleNotFoundError(
"package `optimum` is not installed. Please install it via `pip install optimum[openvino]`" "package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
) )
else: else:
...@@ -71,6 +71,11 @@ class OptimumLM(HFLM): ...@@ -71,6 +71,11 @@ class OptimumLM(HFLM):
else: else:
model_kwargs["ov_config"] = {} model_kwargs["ov_config"] = {}
model_kwargs["ov_config"].setdefault("CACHE_DIR", "") model_kwargs["ov_config"].setdefault("CACHE_DIR", "")
if "pipeline_parallel" in model_kwargs:
if model_kwargs["pipeline_parallel"]:
model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = (
"PIPELINE_PARALLEL"
)
model_file = Path(pretrained) / "openvino_model.xml" model_file = Path(pretrained) / "openvino_model.xml"
if model_file.exists(): if model_file.exists():
export = False export = False
......
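Illustrative effect of the new `pipeline_parallel` flag on the OpenVINO config (the dict below simply mirrors the logic added in the hunk above):

# When pipeline_parallel is truthy, the OpenVINO config gains a distribution policy.
model_kwargs = {"pipeline_parallel": True, "ov_config": {}}
model_kwargs["ov_config"].setdefault("CACHE_DIR", "")
if model_kwargs.get("pipeline_parallel"):
    model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = "PIPELINE_PARALLEL"
print(model_kwargs["ov_config"])
# {'CACHE_DIR': '', 'MODEL_DISTRIBUTION_POLICY': 'PIPELINE_PARALLEL'}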
...@@ -709,3 +709,21 @@ def flatten_image_list(images: List[List]): ...@@ -709,3 +709,21 @@ def flatten_image_list(images: List[List]):
:return: a list of PIL images, via concatenating all the sub-lists in order. :return: a list of PIL images, via concatenating all the sub-lists in order.
""" """
return [image for image_list in images for image in image_list] return [image for image_list in images for image in image_list]
def handle_stop_sequences(
until: Union[str, List[str], None], eos: Optional[str]
) -> List[str]:
"""Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
if isinstance(until, str):
until = [until]
elif until is None:
until = []
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
if eos is not None and eos not in until:
until.append(eos)
return until
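A few behavior examples for the new helper, following directly from the code above:

# handle_stop_sequences normalizes `until` to a list and appends the EOS token once.
handle_stop_sequences(None, eos="<|endoftext|>")               # ['<|endoftext|>']
handle_stop_sequences("###", eos="<|endoftext|>")              # ['###', '<|endoftext|>']
handle_stop_sequences(["###", "\n\n"], eos=None)               # ['###', '\n\n']
handle_stop_sequences(["<|endoftext|>"], eos="<|endoftext|>")  # ['<|endoftext|>'] (no duplicate)
handle_stop_sequences(42, eos=None)                            # raises ValueError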
...@@ -10,7 +10,12 @@ from tqdm import tqdm ...@@ -10,7 +10,12 @@ from tqdm import tqdm
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, configure_pad_token, undistribute from lm_eval.models.utils import (
Collator,
configure_pad_token,
handle_stop_sequences,
undistribute,
)
from lm_eval.utils import ( from lm_eval.utils import (
eval_logger, eval_logger,
get_rolling_token_windows, get_rolling_token_windows,
...@@ -65,7 +70,7 @@ class VLLM(TemplateLM): ...@@ -65,7 +70,7 @@ class VLLM(TemplateLM):
super().__init__() super().__init__()
if not find_spec("vllm"): if not find_spec("vllm"):
raise Exception( raise ModuleNotFoundError(
"attempted to use 'vllm' LM type, but package `vllm` is not installed. " "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
"Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
) )
...@@ -97,7 +102,7 @@ class VLLM(TemplateLM): ...@@ -97,7 +102,7 @@ class VLLM(TemplateLM):
self.batch_size = ( self.batch_size = (
"auto" "auto"
if isinstance(batch_size, str) and "auto" in batch_size if isinstance(batch_size, str) and "auto" in batch_size
else batch_size else int(batch_size)
) )
if self.data_parallel_size <= 1: if self.data_parallel_size <= 1:
self.model = LLM(**self.model_args) self.model = LLM(**self.model_args)
...@@ -118,7 +123,7 @@ class VLLM(TemplateLM): ...@@ -118,7 +123,7 @@ class VLLM(TemplateLM):
tokenizer if tokenizer else pretrained, tokenizer if tokenizer else pretrained,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, revision=tokenizer_revision,
) )
self.tokenizer = configure_pad_token(self.tokenizer) self.tokenizer = configure_pad_token(self.tokenizer)
self.add_bos_token = add_bos_token self.add_bos_token = add_bos_token
...@@ -179,14 +184,21 @@ class VLLM(TemplateLM): ...@@ -179,14 +184,21 @@ class VLLM(TemplateLM):
def max_gen_toks(self): def max_gen_toks(self):
return self._max_gen_toks return self._max_gen_toks
def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str: def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
""" """
Method to apply a chat template to a list of chat history between user and model. Method to apply a chat template to a list of chat history between user and model.
""" """
return self.tokenizer.apply_chat_template( chat_templated = self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
) )
return chat_templated
@property @property
def tokenizer_name(self) -> str: def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__") return self.tokenizer.name_or_path.replace("/", "__")
...@@ -239,17 +251,25 @@ class VLLM(TemplateLM): ...@@ -239,17 +251,25 @@ class VLLM(TemplateLM):
# but then tensor_parallel breaks # but then tensor_parallel breaks
@ray.remote @ray.remote
def run_inference_one_model( def run_inference_one_model(
model_args: dict, sampling_params, requests: List[List[int]] model_args: dict,
sampling_params,
requests: List[List[int]],
lora_request: LoRARequest,
): ):
llm = LLM(**model_args) llm = LLM(**model_args)
return llm.generate( return llm.generate(
prompt_token_ids=requests, sampling_params=sampling_params prompt_token_ids=requests,
sampling_params=sampling_params,
lora_request=lora_request,
) )
# dispatch requests to all self.data_parallel_size workers, in interleaved fashion # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers # interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)] requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
inputs = ((self.model_args, sampling_params, req) for req in requests) inputs = (
(self.model_args, sampling_params, req, self.lora_request)
for req in requests
)
object_refs = [run_inference_one_model.remote(*x) for x in inputs] object_refs = [run_inference_one_model.remote(*x) for x in inputs]
results = ray.get(object_refs) results = ray.get(object_refs)
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required. # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
...@@ -257,28 +277,32 @@ class VLLM(TemplateLM): ...@@ -257,28 +277,32 @@ class VLLM(TemplateLM):
# flatten results # flatten results
return undistribute(results) return undistribute(results)
if self.lora_request is not None: outputs = self.model.generate(
outputs = self.model.generate( prompt_token_ids=requests,
prompt_token_ids=requests, sampling_params=sampling_params,
sampling_params=sampling_params, use_tqdm=True if self.batch_size == "auto" else False,
use_tqdm=True if self.batch_size == "auto" else False, lora_request=self.lora_request,
lora_request=self.lora_request, )
)
else:
outputs = self.model.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
)
return outputs return outputs
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]: ) -> List[float]:
loglikelihoods = [] adaptive_batch_size = None
if self.batch_size == "auto":
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm): adaptive_batch_size = len(requests)
rolling_token_windows = list(
# First, collect all windows from all requests
all_windows = [] # List of (request_idx, window) tuples
request_window_counts = [] # Track number of windows per request
for req_idx, (string,) in enumerate(
tqdm(
[req.args for req in requests],
disable=(disable_tqdm or (self.rank != 0)),
)
):
rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
map( map(
make_disjoint_window, make_disjoint_window,
get_rolling_token_windows( get_rolling_token_windows(
...@@ -291,20 +315,42 @@ class VLLM(TemplateLM): ...@@ -291,20 +315,42 @@ class VLLM(TemplateLM):
) )
) )
rolling_token_windows = [(None,) + x for x in rolling_token_windows] # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
windows = [(None,) + x for x in rolling_token_windows]
string_nll = self._loglikelihood_tokens( # Store windows with their request index
rolling_token_windows, all_windows.extend((req_idx, window) for window in windows)
) request_window_counts.append(len(windows))
# discard is_greedy all_nlls = []
string_nll = [x[0] for x in string_nll] batch_size = adaptive_batch_size or int(self.batch_size)
for i in range(0, len(all_windows), batch_size):
batch = all_windows[i : i + batch_size]
# Extract just the windows for processing, keeping track of request indices
batch_indices, batch_windows = zip(*batch)
string_nll = sum(string_nll) batch_nlls = self._loglikelihood_tokens(
loglikelihoods.append(string_nll) requests=batch_windows,
disable_tqdm=False,
)
# Store results with their request indices
all_nlls.extend(zip(batch_indices, batch_nlls))
# cache this loglikelihood_rolling request # Reconstruct per-request loglikelihoods
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll) loglikelihoods = []
current_idx = 0
for window_count in request_window_counts:
# Get all nlls for this request
request_nlls = all_nlls[current_idx : current_idx + window_count]
# Sum up the nlls for this request (discarding is_greedy)
request_total = sum(nll[0] for _, nll in request_nlls)
loglikelihoods.append(request_total)
current_idx += window_count
string = requests[len(loglikelihoods) - 1].args[0]
self.cache_hook.add_partial(
"loglikelihood_rolling", (string,), request_total
)
return loglikelihoods return loglikelihoods
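A minimal sketch, with invented window counts and scores, of the regrouping performed above: windows from all requests are flattened, scored in batches, then summed back per request using the stored per-request window counts:

# Reconstruct per-request loglikelihoods from a flat list of (request_idx, (nll, is_greedy)).
request_window_counts = [2, 1, 3]  # request 0 had 2 windows, request 1 had 1, request 2 had 3
all_nlls = [
    (0, (-1.0, True)), (0, (-2.0, True)),
    (1, (-0.5, True)),
    (2, (-1.5, True)), (2, (-2.5, True)), (2, (-0.25, True)),
]

loglikelihoods = []
current_idx = 0
for window_count in request_window_counts:
    request_nlls = all_nlls[current_idx : current_idx + window_count]
    request_total = sum(nll[0] for _, nll in request_nlls)
    loglikelihoods.append(request_total)
    current_idx += window_count

print(loglikelihoods)  # [-3.0, -0.5, -4.25]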
...@@ -345,6 +391,7 @@ class VLLM(TemplateLM): ...@@ -345,6 +391,7 @@ class VLLM(TemplateLM):
desc="Running generate_until requests", desc="Running generate_until requests",
) )
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
eos = self.tokenizer.decode(self.eot_token_id)
for chunk in chunks: for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk) context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding) context, context_encoding = zip(*context_and_encoding)
...@@ -352,27 +399,14 @@ class VLLM(TemplateLM): ...@@ -352,27 +399,14 @@ class VLLM(TemplateLM):
# this is safe to assume because the `grouper` object ensures it. # this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0] gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments. # unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict): if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys(): # add EOS token to stop sequences
until = kwargs.pop("until") until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else: else:
raise ValueError( raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}" f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
) )
# add EOS token to stop sequences
eos = self.tokenizer.decode(self.eot_token_id)
if not until:
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys(): if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks") max_gen_toks = kwargs.pop("max_gen_toks")
else: else:
......