"examples/runtime/vscode:/vscode.git/clone" did not exist on "85d2365d337ca81eb353645bca15a199cc348847"
Commit b58e5556 authored by Baber

Merge branch 'main' into tasklist

# Conflicts:
#	pyproject.toml
parents 6e1866f5 4f8195f1
@@ -35,7 +35,7 @@ repos:
       - id: ruff
         args:
           - --fix
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
@@ -43,8 +43,10 @@ repos:
       - id: codespell
         exclude: >
           (?x)^(
             .*\.json|ignore.txt|lm_eval/tasks/.*|.*yaml|.*\.ipynb
           )$
         args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
   - repo: https://github.com/jackdewinter/pymarkdown
     rev: v0.9.30
@@ -52,9 +54,3 @@ repos:
       - id: pymarkdown
         exclude: ^(lm_eval/tasks/.*|docs/footguns\.md)$
         args: [fix, -r]
-  # - repo: https://github.com/pre-commit/mirrors-mypy
-  #   rev: v1.5.1
-  #   hooks:
-  #     - id: mypy
-  #       additional_dependencies: [".[sentencepiece,multilingual,promptsource,gptq]", "types-PyYAML", "types-requests"]
-  #       exclude: ^tests/.*$
@@ -5,7 +5,7 @@
 ---
 ## Latest News 📣
+- [2025/07] Added `think_end_token` arg to `hf` (token/str), `vllm` and `sglang` (str) for stripping CoT reasoning traces from models that support it.
 - [2025/03] Added support for steering HF models!
 - [2025/02] Added [SGLang](https://docs.sglang.ai/) support!
 - [2024/09] We are prototyping allowing users of LM Evaluation Harness to create and evaluate on text+image multimodal input, text output tasks, and have just added the `hf-multimodal` and `vllm-vlm` model types and `mmmu` task as a prototype feature. We welcome users to try out this in-progress feature and stress-test it for themselves, and suggest they check out [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval), a wonderful project originally forking off of the lm-evaluation-harness, for a broader range of multimodal tasks, models, and features.
...
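As an aside, a hedged sketch of how the new argument might be passed from Python; the model name, task, and `</think>` marker are illustrative choices, not defaults of the harness:

```python
import lm_eval

# Illustrative only: pick a reasoning-capable model and its actual end-of-thinking marker.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=Qwen/Qwen3-8B,think_end_token=</think>",
    tasks=["gsm8k"],
    apply_chat_template=True,
)
```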
@@ -21,7 +21,11 @@ When subclassing `TemplateAPI`, you need to implement the following methods:
 1. `_create_payload`: Creates the JSON payload for API requests.
 2. `parse_logprobs`: Parses log probabilities from API responses.
 3. `parse_generations`: Parses generated text from API responses.
-4. `headers`: Returns the headers for the API request.
+
+Optional Properties:
+
+4. `header`: Returns the headers for the API request.
+5. `api_key`: Returns the API key for authentication (if required).
 
 You may also need to override other methods or properties depending on your API's specific requirements.
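To make the list above concrete, here is a minimal sketch of such a subclass, assuming an OpenAI-style completions endpoint. The class name, payload fields, and response schema are illustrative assumptions, not part of the harness:

```python
from typing import List, Optional

from lm_eval.models.api_models import TemplateAPI


class MyCompletionsAPI(TemplateAPI):
    def _create_payload(self, messages, *, generate=True, gen_kwargs: Optional[dict] = None, **kwargs) -> dict:
        # Build the JSON body your endpoint expects (field names here are assumptions).
        return {"model": self.model, "prompt": messages, **(gen_kwargs or {})}

    @staticmethod
    def parse_generations(outputs, **kwargs) -> List[str]:
        # Extract generated text from the (assumed) response schema.
        outputs = outputs if isinstance(outputs, list) else [outputs]
        return [choice["text"] for out in outputs for choice in out.get("choices", [])]

    @staticmethod
    def parse_logprobs(outputs, tokens=None, ctxlens=None, **kwargs):
        # Return (summed logprob, is_greedy) pairs per continuation; depends on your endpoint's format.
        raise NotImplementedError("fill in for your endpoint's logprob schema")

    @property
    def header(self) -> dict:
        # Optional override; otherwise the default bearer-token header is used.
        return {"Authorization": f"Bearer {self.api_key}"}
```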
@@ -97,6 +101,10 @@ When initializing a `TemplateAPI` instance or a subclass, you can provide several
   - Whether to validate the certificate of the API endpoint (if HTTPS).
   - Default is True.
+- `header` (dict, optional):
+  - Custom headers for API requests.
+  - If not provided, uses `{"Authorization": f"Bearer {self.api_key}"}` by default.
 
 Example usage:
 
 ```python
...
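A short, hedged example of passing the new `header` argument when constructing an API model directly; the endpoint URL, model name, and header values are placeholders:

```python
from lm_eval.models.openai_completions import LocalCompletionsAPI

lm = LocalCompletionsAPI(
    base_url="http://localhost:8000/v1/completions",  # placeholder endpoint
    model="my-model",                                  # placeholder model name
    header={"Authorization": "Bearer MY_TOKEN", "X-Org-Id": "my-org"},
)
```

When `header` is omitted, the default `{"Authorization": f"Bearer {self.api_key}"}` header described above is used.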
@@ -435,10 +435,15 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         # because it's already been determined based on the prior env var before launching our
         # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
         import datasets
+        from packaging.version import parse as vparse
 
-        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
+        if vparse(datasets.__version__) < vparse("4.0.0"):
+            datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
 
-        args.model_args = args.model_args + ",trust_remote_code=True"
+        if isinstance(args.model_args, dict):
+            args.model_args["trust_remote_code"] = True
+        else:
+            args.model_args = args.model_args + ",trust_remote_code=True"
 
     (
         eval_logger.info(f"Selected Tasks: {task_names}")
         if eval_logger.getEffectiveLevel() >= logging.INFO
...
@@ -505,7 +505,6 @@ def bootstrap_stderr(
     if not os.getenv("DISABLE_MULTIPROC"):
         import multiprocessing as mp
 
-        pool = mp.Pool(mp.cpu_count())
 
         # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
         # equivalent to stderr calculated without Bessel's correction in the stddev.
         # Unfortunately, I haven't been able to figure out what the right correction is
@@ -517,17 +516,16 @@ def bootstrap_stderr(
         from tqdm import tqdm
 
         print("bootstrapping for stddev:", f.__name__)
-        for bootstrap in tqdm(
-            pool.imap(
-                _bootstrap_internal(f, chunk_size),
-                [(i, xs) for i in range(iters // chunk_size)],
-            ),
-            total=iters // chunk_size,
-        ):
-            # sample w replacement
-            res.extend(bootstrap)
-
-        pool.close()
+        with mp.Pool(mp.cpu_count()) as pool:
+            for bootstrap in tqdm(
+                pool.imap(
+                    _bootstrap_internal(f, chunk_size),
+                    [(i, xs) for i in range(iters // chunk_size)],
+                ),
+                total=iters // chunk_size,
+            ):
+                # sample w replacement
+                res.extend(bootstrap)
     else:
         res = _bootstrap_internal_no_mp(f, xs, iters)
...
@@ -3,18 +3,15 @@ import ast
 import logging
 import random
 import re
-from collections.abc import Callable
+from collections.abc import Callable, Iterable, Iterator, Mapping
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import (
     Any,
     Dict,
-    Iterable,
-    Iterator,
     List,
     Literal,
-    Mapping,
     Optional,
     Tuple,
     Union,
@@ -113,7 +110,7 @@ class TaskConfig(dict):
             if "until" not in self.generation_kwargs:
                 eval_logger.warning(
-                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
+                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={self.fewshot_delimiter!r}"
                 )
                 self.generation_kwargs["until"] = [self.fewshot_delimiter]
             else:
@@ -289,17 +286,14 @@ class Task(abc.ABC):
     @abc.abstractmethod
     def has_training_docs(self):
         """Whether the task has a training set"""
-        pass
 
     @abc.abstractmethod
     def has_validation_docs(self):
         """Whether the task has a validation set"""
-        pass
 
     @abc.abstractmethod
     def has_test_docs(self):
         """Whether the task has a test set"""
-        pass
 
     def training_docs(self) -> Iterable:
         """
@@ -518,7 +512,6 @@ class Task(abc.ABC):
         The number of times each instance in a dataset is inferred on. Defaults to 1,
         can be increased for techniques like majority voting.
         """
-        pass
 
     @abc.abstractmethod
     def process_results(self, doc, results):
@@ -531,7 +524,6 @@ class Task(abc.ABC):
         :param results:
             The results of the requests created in construct_requests.
         """
-        pass
 
     @abc.abstractmethod
     def aggregation(self):
@@ -540,7 +532,6 @@ class Task(abc.ABC):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metric scores
         """
-        pass
 
     @abc.abstractmethod
     def higher_is_better(self):
@@ -549,7 +540,6 @@ class Task(abc.ABC):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        pass
 
     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)
@@ -675,8 +665,8 @@ class Task(abc.ABC):
             self.aggregation = lambda: {
                 metric_name: get_metric_aggregation(metric_name)
             }
-            setattr(self._config, "metric_list", [{"metric": metric_name}])
-            setattr(self._config, "process_results", None)
+            self._config.metric_list = [{"metric": metric_name}]
+            self._config.process_results = None
 
     def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
         self.fewshot_rnd = random.Random(seed)
@@ -835,7 +825,7 @@ class ConfigurableTask(Task):
                 agg_name = metric_config["aggregation"]
                 if isinstance(agg_name, str):
                     self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                elif callable(agg_name):  # noqa: E721
+                elif callable(agg_name):
                     self._aggregation_list[metric_name] = metric_config[
                         "aggregation"
                     ]
@@ -980,6 +970,10 @@ class ConfigurableTask(Task):
     def download(
         self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> None:
+        from packaging.version import parse as vparse
+
+        if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
+            dataset_kwargs.pop("trust_remote_code", None)
         if isinstance(self.config.custom_dataset, Callable):
             eval_logger.warning(
                 f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
@@ -1498,7 +1492,7 @@ class ConfigurableTask(Task):
         ):  # TODO: ensure that non-multimodal tasks aren't getting visual args
             multimodal_arg = {
                 **multimodal_arg,
-                **{"visual": self.doc_to_image(doc)},
+                "visual": self.doc_to_image(doc),
             }
         if (
@@ -1506,7 +1500,7 @@ class ConfigurableTask(Task):
         ):  # TODO: ensure that non-multimodal tasks aren't getting audio args
             multimodal_arg = {
                 **multimodal_arg,
-                **{"audio": self.doc_to_audio(doc)},
+                "audio": self.doc_to_audio(doc),
             }
         if bool(multimodal_arg):
@@ -1769,7 +1763,7 @@ class MultipleChoiceTask(Task):
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice)),
+                arguments=(ctx, f" {choice}"),
                 idx=i,
                 **kwargs,
             )
...
@@ -35,6 +35,7 @@ from lm_eval.utils import (
     positional_deprecated,
     setup_logging,
     simple_parse_args_string,
+    wrap_text,
 )
@@ -169,8 +170,11 @@ def simple_evaluate(
             )
         ) and not apply_chat_template:
             eval_logger.warning(
-                "Model appears to be an instruct or chat variant but chat template is not applied. "
-                "Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
+                wrap_text(
+                    f"""pretrained={model_args.get("pretrained") if isinstance(model_args, dict) else model_args} appears to be an
+                    instruct or chat variant but chat template is not applied.
+                    Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`).""",
+                )
             )
 
     if delete_requests_cache:
@@ -234,7 +238,9 @@ def simple_evaluate(
     else:
         eval_logger.info(
-            f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
+            wrap_text(
+                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
+            )
         )
     lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
         model_args,
...
@@ -135,6 +135,7 @@ class TemplateAPI(TemplateLM):
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
+        header: Optional[Dict[str, str]] = None,
         max_images: int = 1,
         **kwargs,
     ) -> None:
@@ -152,6 +153,7 @@ class TemplateAPI(TemplateLM):
         self.model = model or pretrained
         self.base_url = base_url
         self.tokenizer = tokenizer
+        self._header = header
         if not isinstance(batch_size, int) and "auto" in batch_size:
             eval_logger.warning(
                 "Automatic batch size is not supported for API models. Defaulting to batch size 1."
@@ -296,7 +298,7 @@ class TemplateAPI(TemplateLM):
     @cached_property
     def header(self) -> dict:
         """Override this property to return the headers for the API request."""
-        return {"Authorization": f"Bearer {self.api_key}"}
+        return self._header or {"Authorization": f"Bearer {self.api_key}"}
 
     @property
     def tokenizer_name(self) -> str:
@@ -447,6 +449,7 @@ class TemplateAPI(TemplateLM):
     async def amodel_call(
         self,
         session: ClientSession,
+        sem: asyncio.Semaphore,
         messages: Union[List[List[int]], List[str], List[JsonChatStr]],
         *,
         generate: bool = True,
@@ -465,6 +468,7 @@ class TemplateAPI(TemplateLM):
             **kwargs,
         )
         cache_method = "generate_until" if generate else "loglikelihood"
+        acquired = await sem.acquire()
         try:
             async with session.post(
                 self.base_url,
@@ -474,7 +478,8 @@ class TemplateAPI(TemplateLM):
                 if not response.ok:
                     error_text = await response.text()
                     eval_logger.warning(
-                        f"API request failed with error message: {error_text}. Retrying..."
+                        f"API request failed! Status code: {response.status}, "
+                        f"Response text: {error_text}. Retrying..."
                     )
                     # raising exception will retry the request
                     response.raise_for_status()
@@ -495,11 +500,12 @@ class TemplateAPI(TemplateLM):
                     self.cache_hook.add_partial(cache_method, cache, res)
             return answers
         # If the retries also fail
-        except RetryError:
-            eval_logger.error(
-                "API request failed after multiple retries. Please check the API status."
-            )
-            return None
+        except BaseException as e:
+            eval_logger.error(f"Exception:{repr(e)}, {outputs}, retrying.")
+            raise e
+        finally:
+            if acquired:
+                sem.release()
 
     def batch_loglikelihood_requests(
         self, chunks: Iterable[List[LogLikelihoodInputs]]
@@ -535,6 +541,7 @@ class TemplateAPI(TemplateLM):
     ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
         ctxlens = ctxlens if ctxlens else [None] * len(requests)
         conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
+        sem = asyncio.Semaphore(self._concurrent)
         async with ClientSession(
             connector=conn, timeout=ClientTimeout(total=self.timeout)
         ) as session:
@@ -542,12 +549,16 @@ class TemplateAPI(TemplateLM):
                 stop=stop_after_attempt(self.max_retries),
                 wait=wait_exponential(multiplier=0.5, min=1, max=10),
                 reraise=True,
+                before_sleep=lambda retry_state: eval_logger.info(
+                    f"Retry attempt {retry_state.attempt_number}"
+                ),
             )(self.amodel_call)
             # Create tasks for each batch of request
             tasks = [
                 asyncio.create_task(
                     retry_(
                         session=session,
+                        sem=sem,
                         messages=message,
                         cache_keys=cache_key,
                         generate=generate,
...
@@ -16,8 +16,8 @@ eval_logger = logging.getLogger(__name__)
 class LocalCompletionsAPI(TemplateAPI):
     def __init__(
         self,
-        base_url=None,
-        tokenizer_backend="huggingface",
+        base_url: str = None,
+        tokenizer_backend: str = "huggingface",
         **kwargs,
     ):
         super().__init__(
@@ -108,9 +108,9 @@ class LocalCompletionsAPI(TemplateAPI):
 class LocalChatCompletion(LocalCompletionsAPI):
     def __init__(
         self,
-        base_url=None,
-        tokenizer_backend=None,
-        tokenized_requests=False,
+        base_url: str = None,
+        tokenizer_backend: str = None,
+        tokenized_requests: bool = False,
         **kwargs,
     ):
         eval_logger.warning(
@@ -236,6 +236,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
             eval_logger.warning(
                 "o1 models do not support `stop` and only support temperature=1"
             )
+
         super().__init__(
             base_url=base_url,
             tokenizer_backend=tokenizer_backend,
...
@@ -11,6 +11,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.models.utils import (
     Collator,
     handle_stop_sequences,
+    postprocess_generated_text,
 )
 from lm_eval.utils import (
     get_rolling_token_windows,
@@ -59,6 +60,8 @@ class SGLangLM(TemplateLM):
         dp_size: int = 1,
         tp_size: int = 1,
         prefix_token_id: Optional[int] = None,
+        # End marker for thinking tags - splits to get response after this token (if provided).
+        think_end_token: Optional[str] = None,
         **kwargs,
     ):
         super().__init__()
@@ -74,6 +77,7 @@ class SGLangLM(TemplateLM):
             "Either context_length or max_model_len may be provided, but not both"
         )
         # Initialize your sglang model here
+        self.think_end_token = think_end_token
         self._max_length = (
             max_model_len if max_model_len is not None else context_length
         )
@@ -263,6 +267,9 @@ class SGLangLM(TemplateLM):
         # cache generations
         for output, context in zip(cont, context):
             generated_text = output.get("text", "")
+            generated_text = postprocess_generated_text(
+                generated_text, until, self.think_end_token
+            )
             res.append(generated_text)
             self.cache_hook.add_partial(
                 "generate_until", (context, gen_kwargs), generated_text
...
@@ -852,3 +852,32 @@ def truncate_tokens(
         right_length = max_length - left_length
         return tokens[:left_length] + tokens[-right_length:]
     return None
+
+
+def postprocess_generated_text(
+    generation: str, stop: Union[list[str], str, None], think_end_token: Optional[str]
+) -> str:
+    """
+    Post-processes the generated text by stripping stop sequences and optional thinking markers.
+
+    Args:
+        generation (str): The generated text to be processed.
+        stop (Optional[list[str]]): Stop sequence(s) to remove. Text is truncated
+            at the first occurrence of any stop sequence.
+        think_end_token (Optional[str]): Token marking end of thinking section. If provided,
+            returns only the text after this token (discarding thinking content).
+
+    Returns:
+        str: The processed generation - text before stop sequences and after thinking sections.
+    """
+    if stop:
+        stop = [stop] if isinstance(stop, str) else stop
+        for term in stop:
+            if len(term) > 0:
+                # ignore '' separator,
+                # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
+                generation = generation.split(term)[0]
+    if think_end_token:
+        generation = generation.split(think_end_token)[-1].lstrip()
+    return generation
@@ -22,6 +22,7 @@ from lm_eval.models.utils import (
     Collator,
     configure_pad_token,
     handle_stop_sequences,
+    postprocess_generated_text,
     undistribute,
 )
 from lm_eval.utils import (
@@ -130,10 +131,14 @@ class VLLM(TemplateLM):
         max_model_len: int = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
-        device: str = "cuda",
         data_parallel_size: int = 1,
         lora_local_path: str = None,
-        enable_thinking: bool = False,
+        # VLLM: enable thinking tags in the prompt.
+        enable_thinking: bool = True,
+        chat_template_args: Optional[dict] = None,
+        # End marker for thinking tags - splits to get response after this token (if provided).
+        think_end_token: Optional[str] = None,
+        max_lora_rank: int = 16,
         **kwargs,
     ):
         super().__init__()
@@ -147,6 +152,8 @@ class VLLM(TemplateLM):
         assert max_length is None or max_model_len is None, (
             "Either max_length or max_model_len may be provided, but not both"
         )
+        kwargs.pop("device", None)
+        self.think_end_token = think_end_token
         self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
         self._max_length = max_model_len if max_model_len is not None else max_length
         self.tensor_parallel_size = int(tensor_parallel_size)
@@ -166,7 +173,8 @@ class VLLM(TemplateLM):
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
-            "device": str(device),
+            "enable_lora": True if lora_local_path else False,
+            "max_lora_rank": int(max_lora_rank),
         }
         self.model_args.update(kwargs)
         self.batch_size = (
@@ -201,7 +209,10 @@ class VLLM(TemplateLM):
             add_bos_token=add_bos_token,
         )
         self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
-        self.enable_thinking = enable_thinking
+        self.chat_template_args = chat_template_args or {}
+        self.enable_thinking = self.chat_template_args.pop(
+            "enable_thinking", enable_thinking
+        )
         self.add_bos_token = add_bos_token
         if "gemma" in pretrained.lower():
             self.add_bos_token = True
@@ -309,6 +320,7 @@ class VLLM(TemplateLM):
                 continue_final_message=not add_generation_prompt,
                 chat_template=self.hf_chat_template,
                 enable_thinking=self.enable_thinking,
+                **self.chat_template_args,
             )
         except jinja2.exceptions.TemplateError:
             eval_logger.warning(
@@ -321,6 +333,7 @@ class VLLM(TemplateLM):
                 continue_final_message=not add_generation_prompt,
                 chat_template=self.hf_chat_template,
                 enable_thinking=self.enable_thinking,
+                **self.chat_template_args,
             )
         return chat_templated
@@ -627,11 +640,11 @@ class VLLM(TemplateLM):
             # cache generations
             for output, context in zip(cont, context):
-                generated_text = output.outputs[0].text
+                generated_text: str = output.outputs[0].text
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                for term in until:
-                    if len(term) > 0:
-                        generated_text = generated_text.split(term)[0]
+                generated_text = postprocess_generated_text(
+                    generated_text, until, self.think_end_token
+                )
                 res.append(generated_text)
                 self.cache_hook.add_partial(
                     "generate_until", (context, gen_kwargs), generated_text
...
@@ -4,9 +4,9 @@ include: _boolq_cot_2shot_yaml
 fewshot_config:
   sampler: first_n
   samples:
-    - context: 'This is a ferry domain, where the task is to transport cars from their start to their goal locations, using a ferry. Each location is accessible by ferry from each other location. The cars can be debarked or boarded, and the ferry can carry only one car at a time. There are 2 locations and 5 cars, numbered consecutively. Currently, the ferry is at l1, with the car c4 on board. The cars are at locations as follows: c0 and c3 are at l1; c1 and c2 are at l0.'
+    - context: "This is a ferry domain, where the task is to transport cars from their start to their goal locations, using a ferry. Each location is accessible by ferry from each other location. The cars can be debarked or boarded, and the ferry can carry only one car at a time. There are 2 locations and 5 cars, numbered consecutively. Currently, the ferry is at l1, with the car c4 on board. The cars are at locations as follows: c0 and c3 are at l1; c1 and c2 are at l0."
       question: 'Is it possible to transition to a state where the action "travel by sea from location l0 to location l1" can be applied?'
       answer: "Let's think step by step. Step 1: Verify if there is a sequence of actions which transforms the current state into a state where the precondition of the action \"travel by sea from location l0 to location l1\" hold. Step 2: The following sequence of actions would transition to such a state: sail from location l1 to location l0, unload the car c4 from the ferry to location l0, board car c1 at location l0. **Final Answer**: Yes."
-    - context: 'There are several cities, each containing several locations, some of which are airports. There are also trucks, which can drive within a single city, and airplanes, which can fly between airports. The goal is to get some packages from various locations to various new locations. There are 2 trucks and 1 airplane, as well as 4 packages. There are 6 locations across 2 cities. The locations are in cities as follows: l0-0, l0-1, and l0-2 are in c0; l1-1, l1-2, and l1-0 are in c1. Currently, a0 is at l1-0, t1 is at l1-1, t0 is at l0-0, p2 and p1 are in t1, p0 and p3 are in a0.'
+    - context: "There are several cities, each containing several locations, some of which are airports. There are also trucks, which can drive within a single city, and airplanes, which can fly between airports. The goal is to get some packages from various locations to various new locations. There are 2 trucks and 1 airplane, as well as 4 packages. There are 6 locations across 2 cities. The locations are in cities as follows: l0-0, l0-1, and l0-2 are in c0; l1-1, l1-2, and l1-0 are in c1. Currently, a0 is at l1-0, t1 is at l1-1, t0 is at l0-0, p2 and p1 are in t1, p0 and p3 are in a0."
       question: 'Is it possible to transition to a state where the action "offload the object p0 from the truck p0 at location p1" can be applied?'
       answer: "Let's think step by step. Step 1: Verify if there is a sequence of actions which transforms the current state into a state where the precondition of the action \"offload the object p0 from the truck p0 at location p1\" hold. Step 2: Action preconditions are \"p0 is in p0 and p0 is at p1\". Step 3: These facts are not reachable together, as they include mutually exclusive facts \"p0 is in p0 and p0 is at p1\". **Final Answer**: No."
@@ -67,7 +67,7 @@ def span_f1_agg(items):
 def remove_blank_spaces(text):
     text = re.sub(pattern=get_blank_spaces_pattern(), repl="", string=text)
-    text = re.sub("\s+", " ", text)
+    text = re.sub(r"\s+", " ", text)
     return text
 
 def remove_punctuation(text):
...
@@ -67,7 +67,7 @@ def span_f1_agg(items):
 def remove_blank_spaces(text):
     text = re.sub(pattern=get_blank_spaces_pattern(), repl="", string=text)
-    text = re.sub("\s+", " ", text)
+    text = re.sub(r"\s+", " ", text)
     return text
 
 def remove_punctuation(text):
...
@@ -67,7 +67,7 @@ def span_f1_agg(items):
 def remove_blank_spaces(text):
     text = re.sub(pattern=get_blank_spaces_pattern(), repl="", string=text)
-    text = re.sub("\s+", " ", text)
+    text = re.sub(r"\s+", " ", text)
     return text
 
 def remove_punctuation(text):
...
@@ -67,7 +67,7 @@ def span_f1_agg(items):
 def remove_blank_spaces(text):
     text = re.sub(pattern=get_blank_spaces_pattern(), repl="", string=text)
-    text = re.sub("\s+", " ", text)
+    text = re.sub(r"\s+", " ", text)
     return text
 
 def remove_punctuation(text):
...
@@ -67,7 +67,7 @@ def span_f1_agg(items):
 def remove_blank_spaces(text):
     text = re.sub(pattern=get_blank_spaces_pattern(), repl="", string=text)
-    text = re.sub("\s+", " ", text)
+    text = re.sub(r"\s+", " ", text)
    return text
 
 def remove_punctuation(text):
...