Commit 601be343 authored by Baber

Merge branch 'main' into feature/eval_from_config

parents d0884a96 68c3a811
......@@ -79,36 +79,36 @@ jobs:
path: |
test_logs/*
testmodels:
name: External LM Tests
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
cache-dependency-path: pyproject.toml
# Cache HuggingFace cache directory for External LM tests
- name: Cache HuggingFace cache (External LM tests)
uses: actions/cache@v3
id: cache-hf-lm
with:
path: ~/.cache/huggingface
key: ${{ runner.os }}-hf-cache-external-lm
restore-keys: |
${{ runner.os }}-hf-cache-external-lm
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -U transformers peft accelerate
- name: Test with pytest
run: python -m pytest tests/models --showlocals -s -vv
continue-on-error: true # Continue workflow even if tests fail
# testmodels:
# name: External LM Tests
# runs-on: ubuntu-latest
# timeout-minutes: 30
# steps:
# - name: Checkout Code
# uses: actions/checkout@v4
# - name: Set up Python 3.9
# uses: actions/setup-python@v5
# with:
# python-version: 3.9
# cache: pip
# cache-dependency-path: pyproject.toml
#
# # Cache HuggingFace cache directory for External LM tests
# - name: Cache HuggingFace cache (External LM tests)
# uses: actions/cache@v3
# id: cache-hf-lm
# with:
# path: ~/.cache/huggingface
# key: ${{ runner.os }}-hf-cache-external-lm
# restore-keys: |
# ${{ runner.os }}-hf-cache-external-lm
#
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
# pip install -U transformers peft accelerate
#
# - name: Test with pytest
# run: python -m pytest tests/models --showlocals -s -vv
# continue-on-error: true # Continue workflow even if tests fail
env
*.pyc
output/
data/
lm_cache
.idea
build
dist
*.egg-info
venv
# macOS system files
.DS_Store
# Virtual environments
.venv/
venv/
ENV/
env/
*.env
# Python bytecode and build artifacts
__pycache__/
*.py[cod]
*.so
*.egg-info/
build/
dist/
# IDE & editor settings
.vscode/
temp
__pycache__
.ipynb_checkpoints
temp
test_logs/
# JetBrains IDE
.idea/
# Jupyter / IPython
.ipynb_checkpoints/
profile_default/
ipython_config.py
# don't track (the default location of) the cached requests
# Output and data
output/
data/
temp/
test_logs/
# Caching
lm_eval/caching/.cache
# don't track files created by wandb
wandb
examples/wandb
lm_cache/
# Logging
*.log
logs/
# wandb experiment tracking
wandb/
examples/wandb/
# PyInstaller
*.spec
......@@ -29,7 +29,7 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.0
rev: v0.11.10
hooks:
# Run the linter.
- id: ruff
......@@ -50,7 +50,7 @@ repos:
rev: v0.9.29
hooks:
- id: pymarkdown
exclude: ^lm_eval/tasks/
exclude: ^(lm_eval/tasks/.*|docs/footguns\.md)$
args: [fix, -r]
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.5.1
......
recursive-include tests
......@@ -614,7 +614,7 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
```text
@misc{eval-harness,
author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
title = {A framework for few-shot language model evaluation},
title = {The Language Model Evaluation Harness},
month = 07,
year = 2024,
publisher = {Zenodo},
......
# Common Pitfalls and Troubleshooting Guide
This document highlights common pitfalls and troubleshooting tips when using this library. We'll continue to add more tips as we discover them.
## YAML Configuration Issues
### Newline Characters in YAML (`\n`)
**Problem:** When specifying newline characters in YAML, they may be interpreted incorrectly depending on how you format them.
```yaml
# ❌ WRONG: Single quotes don't process escape sequences
generation_kwargs:
  until: ['\n']  # Parsed as the two literal characters '\' and 'n', i.e. "\\n"
```
```yaml
# ✅ RIGHT: Use double quotes for escape sequences
generation_kwargs:
until: ["\n"] # Gets parsed as an actual newline character
```
**Solutions:**
- Use double quotes for strings containing escape sequences
- For multiline content, use YAML's block scalars (`|` or `>`); see the sketch below
- When generating YAML programmatically, be careful with how template engines handle escape sequences
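A quick way to sanity-check these rules is to run the strings through PyYAML yourself. This is an illustrative sketch, not part of the harness; it assumes `pyyaml` is installed and uses made-up keys:

```python
import yaml

# Single quotes keep the backslash and the 'n' as two literal characters.
single = yaml.safe_load("until: ['\\n']")
print(single)  # {'until': ['\\n']}

# Double quotes turn \n into a real newline character.
double = yaml.safe_load('until: ["\\n"]')
print(repr(double["until"][0]))  # '\n'

# Block scalars (| or >) are another way to get real newlines into a value.
block = yaml.safe_load("prompt: |\n  line one\n  line two\n")
print(repr(block["prompt"]))  # 'line one\nline two\n'
```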
### Quoting in YAML
**When to use different types of quotes:**
- **No quotes**: Simple values (numbers, booleans, alphanumeric strings without special characters)
```yaml
simple_value: plain text
number: 42
```
- **Single quotes (')**:
- Preserves literal values
- Use when you need special characters to be treated literally
- Escape single quotes by doubling them: `'It''s working'`
```yaml
literal_string: 'The newline character \n is not processed here'
path: 'C:\Users\name' # Backslashes preserved
```
- **Double quotes (")**:
- Processes escape sequences like `\n`, `\t`, etc.
- Use for strings that need special characters interpreted
- Escape double quotes with backslash: `"He said \"Hello\""`
```yaml
processed_string: "First line\nSecond line" # Creates actual newline
unicode: "Copyright symbol: \u00A9" # Unicode character
```
......@@ -4,4 +4,4 @@ import os
from .evaluator import evaluate, simple_evaluate
__version__ = "0.4.8"
__version__ = "0.4.9"
......@@ -164,7 +164,7 @@ def setup_parser() -> argparse.ArgumentParser:
type=str,
action=TrackExplicitAction,
metavar="DIR|DIR/file.json",
help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
)
parser.add_argument(
"--limit",
......
......@@ -1481,7 +1481,10 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
aux_arguments = [("", f"{choice}") for choice in choices]
# TODO: should these be strided? will have to modify the processing in process_results if so
aux_arguments = [
("", f"{target_delimiter}{choice}") for choice in choices
]
arguments.extend(aux_arguments)
......@@ -1580,11 +1583,12 @@ class ConfigurableTask(Task):
):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
# as we extend the args list with unconditional ("", continuation) pairs
lls_unconditional = lls[len(choices) :]
if len(lls_unconditional) != len(choices):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
lls = lls[: len(choices)]
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
......
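To make the new slicing concrete: with N choices, the request list is now the N conditional pairs followed by the N unconditional ("", choice) pairs, so the loglikelihoods are split by offset rather than by stride. A toy sketch with invented numbers, not harness code:

```python
import numpy as np

choices = ["yes", "no", "maybe"]  # N = 3 answer options
# loglikelihoods arrive in request order: N conditional, then N unconditional
lls = np.array(
    [-1.2, -2.5, -3.0,   # log P(choice | ctx)
     -0.5, -2.2, -3.5]   # log P(choice) from the ("", choice) requests
)

n = len(choices)
lls_conditional = lls[:n]    # replaces the old stride lls[::2]
lls_unconditional = lls[n:]  # replaces the old stride lls[1::2]

pred = int(np.argmax(lls_conditional))
pred_mutual_info = int(np.argmax(lls_conditional - lls_unconditional))
print(choices[pred], choices[pred_mutual_info])  # yes maybe
```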
......@@ -54,6 +54,51 @@ class RegexFilter(Filter):
return filtered
filtered_resps = list(map(lambda x: filter_set(x), resps))
return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = list(map(lambda x: filter_set(x), resps))
return filtered_resps
......
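For orientation, a hedged sketch of what `regex_pos` extracts from a raw model response; the response text below is invented:

```python
import re

def extract_pos_tags(text: str, fallback=("invalid",)) -> list:
    # Mirrors POSFilter.apply: pull (token, tag) pairs out of a stringified
    # list of tuples, keep only the tags, and fall back when nothing matches.
    pairs = re.findall(r"\('([^']*)', '([^']*)'\)", text)
    tags = [pos for _, pos in pairs]
    return tags if tags else list(fallback)

print(extract_pos_tags("[('The', 'DET'), ('cat', 'NOUN'), ('sleeps', 'VERB')]"))
# ['DET', 'NOUN', 'VERB']
print(extract_pos_tags("no tagged tuples here"))
# ['invalid']
```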
import re
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
......@@ -54,3 +56,67 @@ class MapFilter(Filter):
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
return [filter_set(resp) for resp in resps]
@register_filter("format_span")
class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def format_ner_text(text):
label_dict = {
"person": "PER",
"location": "LOC",
"organization": "ORG",
"counties": "LOC",
"places": "LOC",
"people": "PER",
"persons": "PER",
"company": "ORG",
"country": "LOC",
"continent": "LOC",
"time": "DATE",
"date": "DATE",
"per": "PER",
"loc": "LOC",
"org": "ORG",
}
text = text.lower()
for key, value in label_dict.items():
text = text.replace(key, value)
text = "$".join(i for i in text.split("$$"))
return text.rstrip("$$")
def format_named_entities(text):
"""
Extract named entities from text and format them as 'label: value $$ label: value'.
Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
"""
# Regular expression to match label: entities pattern
pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
# Normalize newline characters
text = text.replace("\n", "$").strip()
matches = re.findall(pattern, text)
formatted_entities = []
for label, values in matches:
# Split multiple entities separated by commas and strip whitespace
entities = [value.strip() for value in values.split(",")]
# Exclude 'none' entities
for entity in entities:
if entity.lower() != "none":
formatted_entities.append(f"{label.lower()}: {entity}")
# Join entities with the desired separator
return " $ ".join(formatted_entities)
def filter_set(inst):
return [
format_named_entities(format_ner_text(resp.lower())) for resp in inst
]
return [filter_set(resp) for resp in resps]
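Roughly what the `format_span` pipeline produces for a single response; the input string is invented, and the second half re-implements `format_named_entities` on the already-normalized text purely for illustration:

```python
import re

raw = "Person: Ada Lovelace $$ Location: Kenya, Uganda $$ Date: none"
# After format_ner_text (lower-casing, label mapping, '$$' -> '$'):
normalized = "PER: ada lovelace $ LOC: kenya, uganda $ DATE: none"

# format_named_entities then splits grouped entities and drops 'none':
matches = re.findall(r"\b(PER|LOC|ORG|DATE):\s*([^$]+)", normalized)
entities = [
    f"{label.lower()}: {value.strip()}"
    for label, values in matches
    for value in values.split(",")
    if value.strip().lower() != "none"
]
print(" $ ".join(entities))  # per: ada lovelace $ loc: kenya $ loc: uganda
```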
......@@ -229,11 +229,21 @@ class EvaluationTracker:
)
path = Path(self.output_path if self.output_path else Path.cwd())
path = path.joinpath(self.general_config_tracker.model_name_sanitized)
path.mkdir(parents=True, exist_ok=True)
self.date_id = datetime.now().isoformat().replace(":", "-")
file_results_aggregated = path.joinpath(f"results_{self.date_id}.json")
if path.suffix == ".json":
path.parent.mkdir(parents=True, exist_ok=True)
file_results_aggregated = path.with_name(
f"{path.stem}_{self.date_id}.json"
)
else:
path = path.joinpath(
self.general_config_tracker.model_name_sanitized
)
path.mkdir(parents=True, exist_ok=True)
file_results_aggregated = path.joinpath(
f"results_{self.date_id}.json"
)
file_results_aggregated.open("w", encoding="utf-8").write(dumped)
if self.api and self.push_results_to_hub:
......@@ -250,12 +260,10 @@ class EvaluationTracker:
)
self.api.upload_file(
repo_id=repo_id,
path_or_fileobj=str(
path.joinpath(f"results_{self.date_id}.json")
),
path_or_fileobj=str(file_results_aggregated),
path_in_repo=os.path.join(
self.general_config_tracker.model_name,
f"results_{self.date_id}.json",
file_results_aggregated.name,
),
repo_type="dataset",
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
......@@ -290,7 +298,12 @@ class EvaluationTracker:
eval_logger.info(f"Saving per-sample results for: {task_name}")
path = Path(self.output_path if self.output_path else Path.cwd())
path = path.joinpath(self.general_config_tracker.model_name_sanitized)
if path.suffix == ".json":
path = path.parent
else:
path = path.joinpath(
self.general_config_tracker.model_name_sanitized
)
path.mkdir(parents=True, exist_ok=True)
file_results_samples = path.joinpath(
......
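A minimal sketch of the new output-path behaviour (the helper below and its arguments are illustrative, not tracker code): a `.json` `--output_path` gets the date appended to its stem, while a directory path keeps the old per-model layout.

```python
from datetime import datetime
from pathlib import Path

def aggregated_results_path(output_path: str, model_name_sanitized: str) -> Path:
    # Mirrors the branching added above, without touching the filesystem.
    date_id = datetime.now().isoformat().replace(":", "-")
    path = Path(output_path)
    if path.suffix == ".json":
        return path.with_name(f"{path.stem}_{date_id}.json")
    return path / model_name_sanitized / f"results_{date_id}.json"

print(aggregated_results_path("out/results.json", "my-model"))
# out/results_<date>.json
print(aggregated_results_path("out", "my-model"))
# out/my-model/results_<date>.json
```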
......@@ -16,6 +16,7 @@ from . import (
optimum_ipex,
optimum_lm,
sglang_causallms,
sglang_generate_API,
textsynth,
vllm_causallms,
vllm_vlms,
......
......@@ -6,6 +6,7 @@ import json
import logging
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
......@@ -30,7 +31,9 @@ except ModuleNotFoundError:
pass
import base64
from importlib.util import find_spec
from io import BytesIO
from lm_eval import utils
from lm_eval.api.instance import Instance
......@@ -38,6 +41,10 @@ from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token
if TYPE_CHECKING:
from PIL import Image
eval_logger = logging.getLogger(__name__)
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
......@@ -51,7 +58,52 @@ class JsonChatStr(NamedTuple):
return self.prompt.encode(encoding)
def create_image_prompt(
imgs: list["Image.Image"], chat: list[dict], fmt: str = "PNG"
) -> list[dict]:
"""
Parameters
----------
imgs : list[PIL.Image.Image]
The list of images to encode to base64
chat : list[dict]
Chat history; the encoded images are prepended to the last message's content.
fmt : str, optional
Any format Pillow understands (e.g. "PNG", "JPEG").
Defaults to "PNG".
Returns
-------
list[dict]
The chat history with image_url entries added to the last message.
"""
images = []
for img in imgs:
buf = BytesIO()
img.save(buf, format=fmt)
img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
img_dict = {
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "auto"},
}
images.append(img_dict)
# chat is in format of list[dict["role": "user"/"system", "content": str, "type": "text"],...]
# with images, we need "content" to be a list of dicts with "type" and "text"/"image_url"
# currently we do not support few-shots so only one user message
# text content also has <image> placeholders, which apparently is not necessary for API class (confirm)
if isinstance(chat[-1]["content"], list):
chat[-1]["content"] = images + chat[-1]["content"]
else:
text_content = {"type": "text", "text": chat[-1]["content"]}
chat[-1]["content"] = images + [text_content]
chat[-1].pop("type")
return chat
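For reference, a hedged example of the message shape this helper produces, using a throwaway 1x1 image; the OpenAI-style `image_url` content format shown is the one assumed by this code path:

```python
import base64
from io import BytesIO

from PIL import Image

img = Image.new("RGB", (1, 1))
buf = BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

chat = [{"role": "user", "content": "Describe the image. <image>", "type": "text"}]
result = create_image_prompt([img], chat)
# result[-1] now looks like:
# {
#     "role": "user",
#     "content": [
#         {"type": "image_url",
#          "image_url": {"url": f"data:image/png;base64,{b64}", "detail": "auto"}},
#         {"type": "text", "text": "Describe the image. <image>"},
#     ],
# }
```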
class TemplateAPI(TemplateLM):
MULTIMODAL = True
def __init__(
self,
model: str = None,
......@@ -83,6 +135,7 @@ class TemplateAPI(TemplateLM):
eos_string: str = None,
# timeout in seconds
timeout: int = 300,
max_images: int = 1,
**kwargs,
) -> None:
super().__init__()
......@@ -129,6 +182,7 @@ class TemplateAPI(TemplateLM):
self.verify_certificate = verify_certificate
self._eos_string = eos_string
self.timeout = int(timeout)
self.max_images = int(max_images)
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
if self.tokenizer_backend is None:
......@@ -265,7 +319,12 @@ class TemplateAPI(TemplateLM):
)
else:
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(json.dumps(chat_history, ensure_ascii=False))
return JsonChatStr(
json.dumps(
[{**item, "type": "text"} for item in chat_history],
ensure_ascii=False,
)
)
@cached_property
def eot_token_id(self) -> Optional[int]:
......@@ -578,7 +637,28 @@ class TemplateAPI(TemplateLM):
return -len(_requests[0])
# Let the API deal with tokenization
requests, all_gen_kwargs = zip(*(req.args for req in requests))
if len(requests[0].args) > 2:
assert self.tokenizer is None, (
"tokenizer is not supported for multimodal requests yet!"
)
eval_logger.info(
f"Using max_images {self.max_images}. Set in the model args."
)
requests, all_gen_kwargs, auxiliary_args = zip(
*(req.args for req in requests)
)
requests = tuple(
JsonChatStr(
json.dumps(
create_image_prompt(
y["visual"][: self.max_images], json.loads(x.prompt)
)
)
)
for x, y in zip(requests, auxiliary_args)
)
else:
requests, all_gen_kwargs = zip(*(req.args for req in requests))
if self.tokenized_requests:
encodings_list = self.tok_encode(
requests, add_special_tokens=self.add_bos_token
......@@ -597,6 +677,10 @@ class TemplateAPI(TemplateLM):
chunked = re_ord.get_batched(
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
)
if not self.tokenized_requests:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
......@@ -615,10 +699,7 @@ class TemplateAPI(TemplateLM):
eval_logger.warning(
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
)
else:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
req = encodings_list if self.tokenized_requests else contexts
outputs = retry(
stop=stop_after_attempt(self.max_retries),
......@@ -664,10 +745,7 @@ class TemplateAPI(TemplateLM):
eval_logger.warning(
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
)
else:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
req = encodings_list if self.tokenized_requests else contexts
results = itertools.chain.from_iterable(
asyncio.run(
......
......@@ -17,6 +17,7 @@ from lm_eval.models.utils import (
handle_stop_sequences,
pad_and_concat,
replace_placeholders,
resize_image,
stop_sequences_criteria,
)
......@@ -45,10 +46,23 @@ class HFMultimodalLM(HFLM):
# TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999,
convert_img_format=False,
# For image resizing
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
image_width: Optional[int] = None,
image_height: Optional[int] = None,
image_max_side: Optional[int] = None,
**kwargs,
):
self.image_width = image_width
self.image_height = image_height
self.image_max_side = image_max_side
if self.image_max_side and (self.image_width or self.image_height):
raise ValueError(
"Ambiguous config for image resize: you can not specify both "
"image_max_side and (image_width or image_height)"
)
# init pixels before calling tokenizer creation to avoid errors
self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
{"max_pixels": max_pixels} if max_pixels else {}
......@@ -385,6 +399,9 @@ class HFMultimodalLM(HFLM):
return batched_imgs
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().loglikelihood_rolling(requests=requests)
raise NotImplementedError(
"model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
"this is because we do not support measuring the loglikelihood a model assigns to an image.",
......@@ -393,6 +410,9 @@ class HFMultimodalLM(HFLM):
def loglikelihood(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
raise NotImplementedError(
"'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
)
......@@ -419,9 +439,11 @@ class HFMultimodalLM(HFLM):
)
)
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
return self._multimodal_loglikelihood_tokens(
new_reqs, disable_tqdm=disable_tqdm
)
def _loglikelihood_tokens(
def _multimodal_loglikelihood_tokens(
self,
requests: List[
Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
......@@ -610,7 +632,10 @@ class HFMultimodalLM(HFLM):
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
# TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
res = []
def _collate(x):
......@@ -646,7 +671,15 @@ class HFMultimodalLM(HFLM):
for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
visuals = [arg["visual"] for arg in aux_arguments]
visuals = [
[
resize_image(
img, self.image_width, self.image_height, self.image_max_side
)
for img in arg["visual"]
]
for arg in aux_arguments
]
if not isinstance(contexts, list):
contexts = list(
......
......@@ -890,7 +890,10 @@ class HFLM(TemplateLM):
input_ids=inps, attention_mask=attn_mask, labels=labels
).logits
else:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
assert self.AUTO_MODEL_CLASS in (
transformers.AutoModelForCausalLM,
transformers.AutoModelForVision2Seq,
)
return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
......@@ -1136,7 +1139,7 @@ class HFLM(TemplateLM):
if self.backend == "causal":
total_length = len(context_enc) + len(continuation_enc)
if total_length > self.max_length + 1:
eval_logger.warn(
eval_logger.warning(
f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
f"exceeds model's maximum length ({self.max_length}). "
f"Truncating {total_length - self.max_length + 1} tokens from the left."
......@@ -1247,7 +1250,12 @@ class HFLM(TemplateLM):
cont_toks = torch.tensor(
cont_toks, dtype=torch.long, device=self.device
).unsqueeze(0) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
# i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
# by choosing key with longest cont if group_by="contexts".
max_equal = (
greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
......
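A small illustration of why the trailing slice is needed once requests are grouped by context (tensors are toy values):

```python
import torch

# greedy_tokens comes from the group's longest continuation; a shorter continuation
# in the same group is compared only against the trailing positions.
greedy_tokens = torch.tensor([[11, 12, 13, 14]])  # argmax over the longest continuation
cont_toks = torch.tensor([[13, 14]])              # shorter continuation, same context group
max_equal = (greedy_tokens[:, -cont_toks.shape[1]:] == cont_toks).all()
print(bool(max_equal))  # True
```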
from typing import Dict, List, Optional, Tuple, Union
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences
@register_model("sglang-generate")
class SGLANGGENERATEAPI(LocalCompletionsAPI):
def __init__(
self,
base_url=None,
tokenizer_backend="huggingface",
**kwargs,
):
super().__init__(
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
)
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
generate=False,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
eos=None,
**kwargs,
) -> dict:
is_string = isinstance(messages, str) or isinstance(messages[0], str)
if generate:
gen_kwargs.pop("do_sample", False)
if "max_tokens" in gen_kwargs:
max_tokens = gen_kwargs.pop("max_tokens")
else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
request = {
"sampling_params": {
"max_new_tokens": max_tokens,
"temperature": temperature,
"stop": stop,
**gen_kwargs,
},
}
if is_string:
request["text"] = messages
else:
request["input_ids"] = messages
return request
else:
assert not is_string, "Logprobs are only supported for tokenized inputs"
request = {
"input_ids": messages,
"sampling_params": {"max_new_tokens": 1, "temperature": 0},
"logprob_start_len": 0,
"top_logprobs_num": 1,
"return_logprob": True,
}
return request
@staticmethod
def parse_logprobs(
outputs: Union[Dict, List[Dict]],
tokens: List[List[int]] = None,
ctxlens: List[int] = None,
**kwargs,
) -> List[Tuple[float, bool]]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for choice, ctxlen in zip(outputs, ctxlens):
choice = choice["meta_info"]
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(x[0] for x in choice["input_token_logprobs"][ctxlen:])
is_greedy = all(
x[1] != y[0][1]
for x, y in zip(
choice["input_token_logprobs"][ctxlen:],
choice["input_top_logprobs"][ctxlen:],
)
)
res.append((logprobs, is_greedy))
return res
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
res.append(out["text"])
return res
@property
def api_key(self):
return ""
......@@ -28,6 +28,7 @@ eval_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from PIL import Image
from transformers import PreTrainedTokenizerBase
from transformers.configuration_utils import PretrainedConfig
......@@ -427,9 +428,13 @@ class Collator:
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
elif self._group_by == "contexts":
# Get one sample from each key
# Get one sample from each key.
# Select longest continuation per group to ensure sufficient context logits
values = self._reorder(
[value[0] for value in self._arr_with_indices.values()]
[
max(value, key=lambda x: len(x[1][-1]))
for value in self._arr_with_indices.values()
]
)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
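A toy illustration of the new selection rule; the grouped structure is simplified, with each entry shaped as (original_index, (context, continuation_tokens)):

```python
grouped = {
    "ctx A": [
        (0, ("ctx A", [7, 8])),
        (1, ("ctx A", [7, 8, 9, 10])),  # longest continuation in the group -> selected
    ],
    "ctx B": [
        (2, ("ctx B", [5])),
    ],
}
selected = [max(group, key=lambda x: len(x[1][-1])) for group in grouped.values()]
print(selected)  # [(1, ('ctx A', [7, 8, 9, 10])), (2, ('ctx B', [5]))]
```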
......@@ -729,3 +734,121 @@ def handle_stop_sequences(
if eos is not None and eos not in until:
until.append(eos)
return until
def resize_image(
image: "Image.Image",
width: Optional[int] = None,
height: Optional[int] = None,
max_dimension: Optional[int] = None,
keep_aspect_ratio: bool = True,
resample_filter: Union[int, str] = "Image.BICUBIC",
min_width: int = 1,
min_height: int = 1,
) -> "Image.Image":
"""
Resizes a PIL Image object with flexible options.
Args:
image: The PIL Image object to resize.
width: Target width in pixels.
height: Target height in pixels.
max_dimension: Maximum size for the longer dimension of the image.
keep_aspect_ratio: If True (default) and both width and height are provided,
the image is resized to fit within these dimensions while
maintaining its aspect ratio. If False, the image is stretched
to the exact width and height.
resample_filter: The resampling filter to use for resizing.
Defaults to Image.BICUBIC.
min_width: Minimum width for the resized image. Defaults to 1.
min_height: Minimum height for the resized image. Defaults to 1.
Returns:
The resized PIL Image object. If no resize parameters are provided
or if the image already meets the criteria, the original image is returned.
Order of precedence for resizing:
1. If width AND height are provided:
- If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio.
- If keep_aspect_ratio is False: Resizes to exact dimensions (may distort).
2. Else if only width is provided: Calculates height proportionally.
3. Else if only height is provided: Calculates width proportionally.
4. Else if max_dimension is provided: Resizes the longest side to max_dimension
and scales the other side proportionally.
5. If none of the above are provided, returns the original image.
"""
original_width, original_height = image.size
# If no arguments are provided, return the original image
if width is None and height is None and max_dimension is None:
return image
new_width = original_width
new_height = original_height
if width is not None and height is not None:
# No resize needed if image is already smaller than target dimensions
if original_width <= width and original_height <= height:
return image
if keep_aspect_ratio:
# Calculate the ratio to fit within the target dimensions
ratio = min(width / original_width, height / original_height)
new_width = int(original_width * ratio)
new_height = int(original_height * ratio)
else:
# Stretch to exact dimensions
new_width = width
new_height = height
elif width is not None:
# No resize needed if width is already smaller
if original_width <= width:
return image
# Calculate height proportionally
new_width = width
new_height = int((original_height / original_width) * new_width)
elif height is not None:
# No resize needed if height is already smaller
if original_height <= height:
return image
# Calculate width proportionally
new_height = height
new_width = int((original_width / original_height) * new_height)
elif max_dimension is not None:
# No resize needed if both dimensions are smaller than max_dimension
if max(original_height, original_width) <= max_dimension:
return image
if original_width > original_height:
# Width is the longer side
new_width = max_dimension
new_height = int((original_height / original_width) * new_width)
else:
# Height is the longer side or sides are equal
new_height = max_dimension
new_width = int((original_width / original_height) * new_height)
# Ensure dimensions are at least minimum values
new_width = max(min_width, new_width)
new_height = max(min_height, new_height)
# Perform the resize operation with the calculated dimensions
if isinstance(resample_filter, str):
# Resolve string defaults such as "Image.BICUBIC" lazily, since PIL is only
# imported under TYPE_CHECKING at module level.
from PIL import Image
resample_filter = getattr(Image, resample_filter.rsplit(".", 1)[-1])
return image.resize((new_width, new_height), resample_filter)
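A few illustrative calls matching the precedence described in the docstring (sizes invented; a Pillow resampling filter is passed explicitly for clarity, and Pillow >= 9.1 is assumed for `Image.Resampling`):

```python
from PIL import Image

img = Image.new("RGB", (1200, 800))
bicubic = Image.Resampling.BICUBIC

print(resize_image(img, width=600, height=600, resample_filter=bicubic).size)  # (600, 400)
print(resize_image(img, width=600, resample_filter=bicubic).size)              # (600, 400)
print(resize_image(img, max_dimension=512, resample_filter=bicubic).size)      # (512, 341)
print(resize_image(img, width=600, height=600, keep_aspect_ratio=False,
                   resample_filter=bicubic).size)                              # (600, 600)
```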
def truncate_tokens(
tokens: List[int],
max_length: int,
tokenizer: "PreTrainedTokenizerBase",
strategy: str = "left",
):
if strategy == "left":
return tokens[-max_length:]
elif strategy == "right":
return tokens[:max_length]
elif strategy == "middle":
# Truncate the middle of the sequence
left_length = max_length // 2
right_length = max_length - left_length
return tokens[:left_length] + tokens[-right_length:]
return None
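And a quick check of the three truncation strategies (the tokenizer argument is unused by these code paths, so None is passed purely for illustration):

```python
toks = list(range(10))

print(truncate_tokens(toks, 6, tokenizer=None, strategy="left"))    # [4, 5, 6, 7, 8, 9]
print(truncate_tokens(toks, 6, tokenizer=None, strategy="right"))   # [0, 1, 2, 3, 4, 5]
print(truncate_tokens(toks, 6, tokenizer=None, strategy="middle"))  # [0, 1, 2, 7, 8, 9]
```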
import copy
import gc
import inspect
import logging
import os
from importlib.metadata import version
from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from more_itertools import distribute
......@@ -28,6 +34,7 @@ try:
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import get_open_port
if parse_version(version("vllm")) >= parse_version("0.8.3"):
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
......@@ -40,6 +47,63 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
model_args: dict,
sampling_params: "SamplingParams",
requests: list[list[int]],
lora_request: "LoRARequest",
result_queue: "Queue",
dp_size: int,
local_dp_rank: int,
dp_master_port: int,
dp_master_ip: str = "127.0.0.1",
) -> None:
"""
Worker process for vLLM multiprocessing.
Initializes a vLLM engine, processes requests, and puts results or errors
onto the result_queue.
"""
if not requests:
result_queue.put((local_dp_rank, []))
return None
os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
llm = None
try:
llm = LLM(**model_args)
res = llm.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
lora_request=lora_request,
)
# Give engines time to pause their processing loops before exiting.
sleep(1)
result_queue.put((local_dp_rank, res))
except Exception as e:
error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
eval_logger.error(error_message, exc_info=True)
result_queue.put((local_dp_rank, {"error": error_message}))
finally:
if llm is not None:
try:
del llm
gc.collect()
except Exception as e_cleanup:
eval_logger.warning(
f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
exc_info=True,
)
return None
@register_model("vllm")
class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048
......@@ -68,6 +132,7 @@ class VLLM(TemplateLM):
device: str = "cuda",
data_parallel_size: int = 1,
lora_local_path: str = None,
enable_thinking: bool = False,
**kwargs,
):
super().__init__()
......@@ -81,7 +146,7 @@ class VLLM(TemplateLM):
assert max_length is None or max_model_len is None, (
"Either max_length or max_model_len may be provided, but not both"
)
self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
self._max_length = max_model_len if max_model_len is not None else max_length
self.tensor_parallel_size = int(tensor_parallel_size)
self.data_parallel_size = int(data_parallel_size)
......@@ -96,9 +161,11 @@ class VLLM(TemplateLM):
"trust_remote_code": trust_remote_code,
"tensor_parallel_size": int(tensor_parallel_size),
"max_model_len": int(self._max_length) if self._max_length else None,
"max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
"swap_space": int(swap_space),
"quantization": quantization,
"seed": int(seed),
"device": str(device),
}
self.model_args.update(kwargs)
self.batch_size = (
......@@ -112,7 +179,11 @@ class VLLM(TemplateLM):
eval_logger.warning(
"You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
)
self.model_args["distributed_executor_backend"] = "ray"
self.model_args["distributed_executor_backend"] = (
"ray"
if not self.V1
else self.model_args.get("distributed_executor_backend", None)
)
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
......@@ -129,6 +200,7 @@ class VLLM(TemplateLM):
add_bos_token=add_bos_token,
)
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
self.enable_thinking = enable_thinking
self.add_bos_token = add_bos_token
if "gemma" in pretrained.lower():
self.add_bos_token = True
......@@ -137,11 +209,36 @@ class VLLM(TemplateLM):
)
if parse_version(version("vllm")) >= parse_version("0.8.3"):
kwargs_resolve_hf_chat_template = {
"tokenizer": self.tokenizer,
"chat_template": None,
"tools": None,
}
if parse_version(version("vllm")) >= parse_version("0.9.0"):
if self.data_parallel_size <= 1:
kwargs_resolve_hf_chat_template["model_config"] = (
self.model.llm_engine.model_config
)
else:
from vllm.engine.arg_utils import EngineArgs
engine_args = EngineArgs(**self.model_args)
model_config = engine_args.create_model_config()
kwargs_resolve_hf_chat_template["model_config"] = model_config
# Older vLLM releases misspell this keyword as "trsut_remote_code";
# see https://github.com/vllm-project/vllm/pull/18259
if (
"trsut_remote_code"
in inspect.signature(resolve_hf_chat_template).parameters
):
kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
else:
kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
self.hf_chat_template = resolve_hf_chat_template(
tokenizer=self.tokenizer,
chat_template=None,
tools=None,
trust_remote_code=trust_remote_code,
**kwargs_resolve_hf_chat_template
)
else:
self.hf_chat_template = None
......@@ -209,6 +306,7 @@ class VLLM(TemplateLM):
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
chat_template=self.hf_chat_template,
enable_thinking=self.enable_thinking,
)
return chat_templated
......@@ -257,7 +355,7 @@ class VLLM(TemplateLM):
sampling_params = SamplingParams(
temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
)
if self.data_parallel_size > 1:
if self.data_parallel_size > 1 and not self.V1:
# vLLM hangs if resources are set in ray.remote
# also seems to only work with decorator and not with ray.remote() fn
# see https://github.com/vllm-project/vllm/issues/973
......@@ -288,14 +386,83 @@ class VLLM(TemplateLM):
ray.shutdown()
# flatten results
return undistribute(results)
elif self.data_parallel_size > 1:
# based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
dp_size = self.data_parallel_size
dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
requests = (list(x) for x in distribute(self.data_parallel_size, requests))
procs, resq = [], Queue()
# We use Process as it is non-daemonic
try:
for rank, req in enumerate(requests):
proc = Process(
target=_vllm_mp_worker,
args=(
self.model_args.copy(),
sampling_params,
req,
self.lora_request,
resq,
dp_size,
rank,
dp_master_port,
dp_master_ip,
),
)
proc.start()
procs.append(proc)
# Collect results
rank_res = {}
while len(rank_res) < len(procs):
try:
rank, result = resq.get(timeout=30)
if isinstance(result, dict) and "error" in result:
raise RuntimeError(result["error"])
rank_res[rank] = result
except Empty:
dead_procs = [
idx
for idx, p in enumerate(procs)
if not p.is_alive() and idx not in rank_res
]
if dead_procs:
raise RuntimeError(
f"Worker processes {dead_procs} died unexpectedly"
)
continue
results = [rank_res[i] for i in range(len(procs))]
return undistribute(results)
# cleanup
finally:
try:
resq.close()
resq.join_thread()
except Exception:
eval_logger.debug(
"Failed to close vllm DP results queue", exc_info=True
)
for proc in procs:
proc.join(timeout=10)
if proc.is_alive():
proc.terminate()
proc.join(timeout=5)
if proc.is_alive():
proc.kill()
outputs = self.model.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
lora_request=self.lora_request,
)
return outputs
else:
outputs = self.model.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
lora_request=self.lora_request,
)
return outputs
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
......@@ -427,6 +594,12 @@ class VLLM(TemplateLM):
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
all_lengths = [len(x) for x in context_encoding]
for length in all_lengths:
if length > max_ctx_len:
eval_logger.warning(
f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
)
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# perform batched generation
......@@ -441,6 +614,10 @@ class VLLM(TemplateLM):
# cache generations
for output, context in zip(cont, context):
generated_text = output.outputs[0].text
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
generated_text = generated_text.split(term)[0]
res.append(generated_text)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text
......@@ -477,6 +654,12 @@ class VLLM(TemplateLM):
inputs = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
if (
full_length := len(context_enc + continuation_enc)
) > self.max_length:
eval_logger.warning(
f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
)
inp = (context_enc + continuation_enc)[-(self.max_length) :]
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length)
......