Commit 601be343 authored by Baber

Merge branch 'main' into feature/eval_from_config

parents d0884a96 68c3a811
@@ -79,36 +79,36 @@ jobs:
           path: |
             test_logs/*
-  testmodels:
-    name: External LM Tests
-    runs-on: ubuntu-latest
-    timeout-minutes: 30
-    steps:
-      - name: Checkout Code
-        uses: actions/checkout@v4
-      - name: Set up Python 3.9
-        uses: actions/setup-python@v5
-        with:
-          python-version: 3.9
-          cache: pip
-          cache-dependency-path: pyproject.toml
-
-      # Cache HuggingFace cache directory for External LM tests
-      - name: Cache HuggingFace cache (External LM tests)
-        uses: actions/cache@v3
-        id: cache-hf-lm
-        with:
-          path: ~/.cache/huggingface
-          key: ${{ runner.os }}-hf-cache-external-lm
-          restore-keys: |
-            ${{ runner.os }}-hf-cache-external-lm
-
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
-          pip install -U transformers peft accelerate
-
-      - name: Test with pytest
-        run: python -m pytest tests/models --showlocals -s -vv
-        continue-on-error: true # Continue workflow even if tests fail
+  # testmodels:
+  #   name: External LM Tests
+  #   runs-on: ubuntu-latest
+  #   timeout-minutes: 30
+  #   steps:
+  #     - name: Checkout Code
+  #       uses: actions/checkout@v4
+  #     - name: Set up Python 3.9
+  #       uses: actions/setup-python@v5
+  #       with:
+  #         python-version: 3.9
+  #         cache: pip
+  #         cache-dependency-path: pyproject.toml
+  #
+  #     # Cache HuggingFace cache directory for External LM tests
+  #     - name: Cache HuggingFace cache (External LM tests)
+  #       uses: actions/cache@v3
+  #       id: cache-hf-lm
+  #       with:
+  #         path: ~/.cache/huggingface
+  #         key: ${{ runner.os }}-hf-cache-external-lm
+  #         restore-keys: |
+  #           ${{ runner.os }}-hf-cache-external-lm
+  #
+  #     - name: Install dependencies
+  #       run: |
+  #         python -m pip install --upgrade pip
+  #         pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
+  #         pip install -U transformers peft accelerate
+  #
+  #     - name: Test with pytest
+  #       run: python -m pytest tests/models --showlocals -s -vv
+  #       continue-on-error: true # Continue workflow even if tests fail
-env
-*.pyc
-output/
-data/
-lm_cache
-.idea
-build
-dist
-*.egg-info
-venv
+# macOS system files
+.DS_Store
+# Virtual environments
 .venv/
 venv/
 ENV/
 env/
 *.env
+# Python bytecode and build artifacts
+__pycache__/
+*.py[cod]
+*.so
+*.egg-info/
+build/
+dist/
+# IDE & editor settings
 .vscode/
-temp
-__pycache__
-.ipynb_checkpoints
-temp
-test_logs/
+.idea/
+# Jupyter
+.ipynb_checkpoints/
+# IPython
 profile_default/
 ipython_config.py
-# don't track (the default location of) the cached requests
+# Output and data
+output/
+data/
+temp/
+test_logs/
+# Caching
 lm_eval/caching/.cache
-# don't track files created by wandb
-wandb
-examples/wandb
+lm_cache/
+# Logging
+*.log
+logs/
+# wandb experiment tracking
+wandb/
+examples/wandb/
+# PyInstaller
+*.spec
@@ -29,7 +29,7 @@ repos:
       - id: mixed-line-ending
         args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.0
+    rev: v0.11.10
    hooks:
      # Run the linter.
      - id: ruff
@@ -50,7 +50,7 @@ repos:
    rev: v0.9.29
    hooks:
      - id: pymarkdown
-        exclude: ^lm_eval/tasks/
+        exclude: ^(lm_eval/tasks/.*|docs/footguns\.md)$
        args: [fix, -r]
  # - repo: https://github.com/pre-commit/mirrors-mypy
  #   rev: v1.5.1
......
recursive-include tests
@@ -614,7 +614,7 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
 ```text
 @misc{eval-harness,
   author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
-  title = {A framework for few-shot language model evaluation},
+  title = {The Language Model Evaluation Harness},
   month = 07,
   year = 2024,
   publisher = {Zenodo},
......
# Common Pitfalls and Troubleshooting Guide
This document highlights common pitfalls and troubleshooting tips when using this library. We'll continue to add more tips as we discover them.
## YAML Configuration Issues
### Newline Characters in YAML (`\n`)
**Problem:** When specifying newline characters in YAML, they may be interpreted incorrectly depending on how you format them.
```yaml
# ❌ WRONG: Single quotes don't process escape sequences
generation_kwargs:
until: ['\n'] # Parsed as the two literal characters '\' and 'n', i.e. "\\n"
```
```yaml
# ✅ RIGHT: Use double quotes for escape sequences
generation_kwargs:
until: ["\n"] # Gets parsed as an actual newline character
```
**Solutions:**
- Use double quotes for strings containing escape sequences
- For multiline content, use YAML's block scalars (`|` or `>`); see the example below
- When generating YAML programmatically, be careful with how template engines handle escape sequences
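For example, a block scalar lets you write multi-line text without any escape sequences. The snippet below is purely illustrative; `doc_to_text` is just one field where a block scalar is commonly useful:

```yaml
# '|' preserves line breaks exactly as written, with no quoting or "\n" needed
doc_to_text: |
  Question: {{question}}
  Answer:
```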
### Quoting in YAML
**When to use different types of quotes:**
- **No quotes**: Simple values (numbers, booleans, alphanumeric strings without special characters)
```yaml
simple_value: plain text
number: 42
```
- **Single quotes (')**:
- Preserves literal values
- Use when you need special characters to be treated literally
- Escape single quotes by doubling them: `'It''s working'`
```yaml
literal_string: 'The newline character \n is not processed here'
path: 'C:\Users\name' # Backslashes preserved
```
- **Double quotes (")**:
- Processes escape sequences like `\n`, `\t`, etc.
- Use for strings that need special characters interpreted
- Escape double quotes with backslash: `"He said \"Hello\""`
```yaml
processed_string: "First line\nSecond line" # Creates actual newline
unicode: "Copyright symbol: \u00A9" # Unicode character
```
@@ -4,4 +4,4 @@ import os
 from .evaluator import evaluate, simple_evaluate

-__version__ = "0.4.8"
+__version__ = "0.4.9"
@@ -164,7 +164,7 @@ def setup_parser() -> argparse.ArgumentParser:
         type=str,
         action=TrackExplicitAction,
         metavar="DIR|DIR/file.json",
-        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
+        help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
     )
     parser.add_argument(
         "--limit",
......
@@ -1481,7 +1481,10 @@ class ConfigurableTask(Task):
             # here mutual info refers to calculating
             # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
             # in other words normalizing by subtracting the unconditional logprob of each choice.
-            aux_arguments = [("", f"{choice}") for choice in choices]
+            # TODO: should these be strided? will have to modify the processing in process_results if so
+            aux_arguments = [
+                ("", f"{target_delimiter}{choice}") for choice in choices
+            ]
             arguments.extend(aux_arguments)
@@ -1580,11 +1583,12 @@
         ):
             # then we are doing mutual info.
             # this stores the "dryrun" / unconditional answer loglikelihoods
-            lls_unconditional = lls[1::2]
+            # as we extend the args list with unconditional ("", continuation) pairs
+            lls_unconditional = lls[len(choices) :]
             if len(lls_unconditional) != len(choices):
                 raise ValueError
             # and this stores our "regular" conditional loglikelihoods
-            lls = lls[::2]
+            lls = lls[: len(choices)]
             pred = np.argmax(lls)
             pred_norm = np.argmax(lls / completion_len)
......
@@ -54,6 +54,51 @@ class RegexFilter(Filter):
                 return filtered

         filtered_resps = list(map(lambda x: filter_set(x), resps))
return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps return filtered_resps
......
import re
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@@ -54,3 +56,67 @@ class MapFilter(Filter):
            return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
        return [filter_set(resp) for resp in resps]
@register_filter("format_span")
class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def format_ner_text(text):
label_dict = {
"person": "PER",
"location": "LOC",
"organization": "ORG",
"counties": "LOC",
"places": "LOC",
"people": "PER",
"persons": "PER",
"company": "ORG",
"country": "LOC",
"continent": "LOC",
"time": "DATE",
"date": "DATE",
"per": "PER",
"loc": "LOC",
"org": "ORG",
}
text = text.lower()
for key, value in label_dict.items():
text = text.replace(key, value)
text = "$".join(i for i in text.split("$$"))
return text.rstrip("$$")
def format_named_entities(text):
"""
Extract named entities from text and format them as 'label: value $$ label: value'.
Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
"""
# Regular expression to match label: entities pattern
pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
# Normalize newline characters
text = text.replace("\n", "$").strip()
matches = re.findall(pattern, text)
formatted_entities = []
for label, values in matches:
# Split multiple entities separated by commas and strip whitespace
entities = [value.strip() for value in values.split(",")]
# Exclude 'none' entities
for entity in entities:
if entity.lower() != "none":
formatted_entities.append(f"{label.lower()}: {entity}")
# Join entities with the desired separator
return " $ ".join(formatted_entities)
def filter_set(inst):
return [
format_named_entities(format_ner_text(resp.lower())) for resp in inst
]
return [filter_set(resp) for resp in resps]
@@ -229,11 +229,21 @@ class EvaluationTracker:
            )
            path = Path(self.output_path if self.output_path else Path.cwd())
-           path = path.joinpath(self.general_config_tracker.model_name_sanitized)
-           path.mkdir(parents=True, exist_ok=True)
            self.date_id = datetime.now().isoformat().replace(":", "-")
-           file_results_aggregated = path.joinpath(f"results_{self.date_id}.json")
+           if path.suffix == ".json":
+               path.parent.mkdir(parents=True, exist_ok=True)
+               file_results_aggregated = path.with_name(
+                   f"{path.stem}_{self.date_id}.json"
+               )
+           else:
+               path = path.joinpath(
+                   self.general_config_tracker.model_name_sanitized
+               )
+               path.mkdir(parents=True, exist_ok=True)
+               file_results_aggregated = path.joinpath(
+                   f"results_{self.date_id}.json"
+               )
            file_results_aggregated.open("w", encoding="utf-8").write(dumped)
            if self.api and self.push_results_to_hub:
@@ -250,12 +260,10 @@
                )
                self.api.upload_file(
                    repo_id=repo_id,
-                   path_or_fileobj=str(
-                       path.joinpath(f"results_{self.date_id}.json")
-                   ),
+                   path_or_fileobj=str(file_results_aggregated),
                    path_in_repo=os.path.join(
                        self.general_config_tracker.model_name,
-                       f"results_{self.date_id}.json",
+                       file_results_aggregated.name,
                    ),
                    repo_type="dataset",
                    commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
@@ -290,7 +298,12 @@
            eval_logger.info(f"Saving per-sample results for: {task_name}")
            path = Path(self.output_path if self.output_path else Path.cwd())
-           path = path.joinpath(self.general_config_tracker.model_name_sanitized)
+           if path.suffix == ".json":
+               path = path.parent
+           else:
+               path = path.joinpath(
+                   self.general_config_tracker.model_name_sanitized
+               )
            path.mkdir(parents=True, exist_ok=True)
            file_results_samples = path.joinpath(
......
@@ -16,6 +16,7 @@ from . import (
     optimum_ipex,
     optimum_lm,
     sglang_causallms,
+    sglang_generate_API,
     textsynth,
     vllm_causallms,
     vllm_vlms,
......
...@@ -6,6 +6,7 @@ import json ...@@ -6,6 +6,7 @@ import json
import logging import logging
from functools import cached_property from functools import cached_property
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
...@@ -30,7 +31,9 @@ except ModuleNotFoundError: ...@@ -30,7 +31,9 @@ except ModuleNotFoundError:
pass pass
import base64
from importlib.util import find_spec from importlib.util import find_spec
from io import BytesIO
from lm_eval import utils from lm_eval import utils
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -38,6 +41,10 @@ from lm_eval.api.model import TemplateLM ...@@ -38,6 +41,10 @@ from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token from lm_eval.models.utils import Collator, chunks, configure_pad_token
if TYPE_CHECKING:
from PIL import Image
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]] LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
...@@ -51,7 +58,52 @@ class JsonChatStr(NamedTuple): ...@@ -51,7 +58,52 @@ class JsonChatStr(NamedTuple):
return self.prompt.encode(encoding) return self.prompt.encode(encoding)
def create_image_prompt(
imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
) -> dict:
"""
Parameters
----------
img : list[PIL.Image.Image]
The list of images to encode to base64
chat : dict
fmt : str, optional
Any format Pillow understands (e.g. "PNG", "JPEG").
Defaults to "PNG".
Returns
-------
dict
"""
images = []
for img in imgs:
buf = BytesIO()
img.save(buf, format=fmt)
img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
img_dict = {
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "auto"},
}
images.append(img_dict)
# chat is in format of list[dict["role": "user"/"system", "content": str, "type": "text"],...]
# with images, we need "content" to be a list of dicts with "type" and "text"/"image_url"
# currently we do not support few-shots so only one user message
# text content also has <image> placeholders, which apparently is not necessary for API class (confirm)
if isinstance(chat[-1]["content"], list):
chat[-1]["content"] = images + chat[-1]["content"]
else:
text_content = {"type": "text", "text": chat[-1]["content"]}
chat[-1]["content"] = images + [text_content]
chat[-1].pop("type")
return chat
class TemplateAPI(TemplateLM): class TemplateAPI(TemplateLM):
MULTIMODAL = True
def __init__( def __init__(
self, self,
model: str = None, model: str = None,
...@@ -83,6 +135,7 @@ class TemplateAPI(TemplateLM): ...@@ -83,6 +135,7 @@ class TemplateAPI(TemplateLM):
eos_string: str = None, eos_string: str = None,
# timeout in seconds # timeout in seconds
timeout: int = 300, timeout: int = 300,
max_images: int = 1,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -129,6 +182,7 @@ class TemplateAPI(TemplateLM): ...@@ -129,6 +182,7 @@ class TemplateAPI(TemplateLM):
self.verify_certificate = verify_certificate self.verify_certificate = verify_certificate
self._eos_string = eos_string self._eos_string = eos_string
self.timeout = int(timeout) self.timeout = int(timeout)
self.max_images = int(max_images)
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}") eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
if self.tokenizer_backend is None: if self.tokenizer_backend is None:
@@ -265,7 +319,12 @@ class TemplateAPI(TemplateLM):
            )
        else:
            # bit of a hack. We'll load back before sending to the API
-           return JsonChatStr(json.dumps(chat_history, ensure_ascii=False))
+           return JsonChatStr(
+               json.dumps(
+                   [{**item, "type": "text"} for item in chat_history],
+                   ensure_ascii=False,
+               )
+           )

    @cached_property
    def eot_token_id(self) -> Optional[int]:
@@ -578,7 +637,28 @@ class TemplateAPI(TemplateLM):
            return -len(_requests[0])

        # Let the API deal with tokenization
-       requests, all_gen_kwargs = zip(*(req.args for req in requests))
+       if len(requests[0].args) > 2:
+           assert self.tokenizer is None, (
+               "tokenizer is not supported for multimodal requests yet!"
+           )
+           eval_logger.info(
+               f"Using max_images {self.max_images}. Set in the model args."
+           )
+           requests, all_gen_kwargs, auxiliary_args = zip(
+               *(req.args for req in requests)
+           )
+           requests = tuple(
+               JsonChatStr(
+                   json.dumps(
+                       create_image_prompt(
+                           y["visual"][: self.max_images], json.loads(x.prompt)
+                       )
+                   )
+               )
+               for x, y in zip(requests, auxiliary_args)
+           )
+       else:
+           requests, all_gen_kwargs = zip(*(req.args for req in requests))
if self.tokenized_requests: if self.tokenized_requests:
encodings_list = self.tok_encode( encodings_list = self.tok_encode(
requests, add_special_tokens=self.add_bos_token requests, add_special_tokens=self.add_bos_token
...@@ -597,6 +677,10 @@ class TemplateAPI(TemplateLM): ...@@ -597,6 +677,10 @@ class TemplateAPI(TemplateLM):
chunked = re_ord.get_batched( chunked = re_ord.get_batched(
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
) )
if not self.tokenized_requests:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
if self._concurrent <= 1: if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests)) pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked: for chunk in chunked:
...@@ -615,10 +699,7 @@ class TemplateAPI(TemplateLM): ...@@ -615,10 +699,7 @@ class TemplateAPI(TemplateLM):
eval_logger.warning( eval_logger.warning(
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated." f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
) )
else:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
req = encodings_list if self.tokenized_requests else contexts req = encodings_list if self.tokenized_requests else contexts
outputs = retry( outputs = retry(
stop=stop_after_attempt(self.max_retries), stop=stop_after_attempt(self.max_retries),
...@@ -664,10 +745,7 @@ class TemplateAPI(TemplateLM): ...@@ -664,10 +745,7 @@ class TemplateAPI(TemplateLM):
eval_logger.warning( eval_logger.warning(
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated." f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
) )
else:
eval_logger.info(
"Tokenized requests are disabled. Context + generation length is not checked."
)
req = encodings_list if self.tokenized_requests else contexts req = encodings_list if self.tokenized_requests else contexts
results = itertools.chain.from_iterable( results = itertools.chain.from_iterable(
asyncio.run( asyncio.run(
......
...@@ -17,6 +17,7 @@ from lm_eval.models.utils import ( ...@@ -17,6 +17,7 @@ from lm_eval.models.utils import (
handle_stop_sequences, handle_stop_sequences,
pad_and_concat, pad_and_concat,
replace_placeholders, replace_placeholders,
resize_image,
stop_sequences_criteria, stop_sequences_criteria,
) )
...@@ -45,10 +46,23 @@ class HFMultimodalLM(HFLM): ...@@ -45,10 +46,23 @@ class HFMultimodalLM(HFLM):
# TODO: handle whitespace in image placeholder (replacement) # TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999, max_images: Optional[int] = 999,
convert_img_format=False, convert_img_format=False,
# For image resizing
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
image_width: Optional[int] = None,
image_height: Optional[int] = None,
image_max_side: Optional[int] = None,
**kwargs, **kwargs,
): ):
self.image_width = image_width
self.image_height = image_height
self.image_max_side = image_max_side
if self.image_max_side and (self.image_width or self.image_height):
raise ValueError(
"Ambiguous config for image resize: you can not specify both "
"image_max_side and (image_width or image_height)"
)
# init pixels before calling tokenizer creation to avoid errors # init pixels before calling tokenizer creation to avoid errors
self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | ( self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
{"max_pixels": max_pixels} if max_pixels else {} {"max_pixels": max_pixels} if max_pixels else {}
...@@ -385,6 +399,9 @@ class HFMultimodalLM(HFLM): ...@@ -385,6 +399,9 @@ class HFMultimodalLM(HFLM):
return batched_imgs return batched_imgs
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().loglikelihood_rolling(requests=requests)
raise NotImplementedError( raise NotImplementedError(
"model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ", "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
"this is because we do not support measuring the loglikelihood a model assigns to an image.", "this is because we do not support measuring the loglikelihood a model assigns to an image.",
...@@ -393,6 +410,9 @@ class HFMultimodalLM(HFLM): ...@@ -393,6 +410,9 @@ class HFMultimodalLM(HFLM):
def loglikelihood( def loglikelihood(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: List[Instance], disable_tqdm: bool = False
) -> List[Tuple[float, bool]]: ) -> List[Tuple[float, bool]]:
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
raise NotImplementedError( raise NotImplementedError(
"'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!" "'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
) )
@@ -419,9 +439,11 @@ class HFMultimodalLM(HFLM):
                )
            )

-       return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
+       return self._multimodal_loglikelihood_tokens(
+           new_reqs, disable_tqdm=disable_tqdm
+       )

-   def _loglikelihood_tokens(
+   def _multimodal_loglikelihood_tokens(
        self,
        requests: List[
            Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
@@ -610,7 +632,10 @@ class HFMultimodalLM(HFLM):
    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
-       # TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
+       if requests and len(requests[0].args) < 3:
+           # Fall back to non-multimodal generation.
+           return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
+
        res = []

        def _collate(x):
@@ -646,7 +671,15 @@ class HFMultimodalLM(HFLM):
        for chunk in chunks:
            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

-           visuals = [arg["visual"] for arg in aux_arguments]
+           visuals = [
+               [
+                   resize_image(
+                       img, self.image_width, self.image_height, self.image_max_side
+                   )
+                   for img in arg["visual"]
+               ]
+               for arg in aux_arguments
+           ]

            if not isinstance(contexts, list):
                contexts = list(
......
@@ -890,7 +890,10 @@ class HFLM(TemplateLM):
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
-               assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+               assert self.AUTO_MODEL_CLASS in (
+                   transformers.AutoModelForCausalLM,
+                   transformers.AutoModelForVision2Seq,
+               )
                return self.model(inps).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
@@ -1136,7 +1139,7 @@ class HFLM(TemplateLM):
        if self.backend == "causal":
            total_length = len(context_enc) + len(continuation_enc)
            if total_length > self.max_length + 1:
-               eval_logger.warn(
+               eval_logger.warning(
                    f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                    f"exceeds model's maximum length ({self.max_length}). "
                    f"Truncating {total_length - self.max_length + 1} tokens from the left."
@@ -1247,7 +1250,12 @@ class HFLM(TemplateLM):
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
-               max_equal = (greedy_tokens == cont_toks).all()
+               # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
+               # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
+               # by choosing key with longest cont if group_by="contexts".
+               max_equal = (
+                   greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
+               ).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
......
from typing import Dict, List, Optional, Tuple, Union
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences
@register_model("sglang-generate")
class SGLANGGENERATEAPI(LocalCompletionsAPI):
def __init__(
self,
base_url=None,
tokenizer_backend="huggingface",
**kwargs,
):
super().__init__(
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
)
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
generate=False,
gen_kwargs: Optional[dict] = None,
seed: int = 1234,
eos=None,
**kwargs,
) -> dict:
is_string = (
True
if (isinstance(messages, str) or isinstance(messages[0], str))
else False
)
if generate:
gen_kwargs.pop("do_sample", False)
if "max_tokens" in gen_kwargs:
max_tokens = gen_kwargs.pop("max_tokens")
else:
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
temperature = gen_kwargs.pop("temperature", 0)
stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
request = {
"sampling_params": {
"max_new_tokens": max_tokens,
"temperature": temperature,
"stop": stop,
**gen_kwargs,
},
}
request.update({"text": messages}) if is_string else request.update(
{"input_ids": messages}
)
return request
else:
assert not is_string, "Logprobs are only supported for tokenized inputs"
request = {
"input_ids": messages,
"sampling_params": {"max_new_tokens": 1, "temperature": 0},
"logprob_start_len": 0,
"top_logprobs_num": 1,
"return_logprob": True,
}
return request
@staticmethod
def parse_logprobs(
outputs: Union[Dict, List[Dict]],
tokens: List[List[int]] = None,
ctxlens: List[int] = None,
**kwargs,
) -> List[Tuple[float, bool]]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for choice, ctxlen in zip(outputs, ctxlens):
choice = choice["meta_info"]
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(x[0] for x in choice["input_token_logprobs"][ctxlen:])
is_greedy = all(
x[1] != y[0][1]
for x, y in zip(
choice["input_token_logprobs"][ctxlen:],
choice["input_top_logprobs"][ctxlen:],
)
)
res.append((logprobs, is_greedy))
return res
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
for out in outputs:
res.append(out["text"])
return res
@property
def api_key(self):
return ""
@@ -28,6 +28,7 @@ eval_logger = logging.getLogger(__name__)
 if TYPE_CHECKING:
+    from PIL import Image
     from transformers import PreTrainedTokenizerBase
     from transformers.configuration_utils import PretrainedConfig
@@ -427,9 +428,13 @@ class Collator:
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
        elif self._group_by == "contexts":
-           # Get one sample from each key
+           # Get one sample from each key.
+           # Select longest continuation per group to ensure sufficient context logits
            values = self._reorder(
-               [value[0] for value in self._arr_with_indices.values()]
+               [
+                   max(value, key=lambda x: len(x[1][-1]))
+                   for value in self._arr_with_indices.values()
+               ]
            )
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
@@ -729,3 +734,121 @@ def handle_stop_sequences(
    if eos is not None and eos not in until:
        until.append(eos)
    return until
def resize_image(
image: "Image.Image",
width: Optional[int] = None,
height: Optional[int] = None,
max_dimension: Optional[int] = None,
keep_aspect_ratio: bool = True,
resample_filter: Union[int, str] = "Image.BICUBIC",
min_width: int = 1,
min_height: int = 1,
) -> "Image.Image":
"""
Resizes a PIL Image object with flexible options.
Args:
image: The PIL Image object to resize.
width: Target width in pixels.
height: Target height in pixels.
max_dimension: Maximum size for the longer dimension of the image.
keep_aspect_ratio: If True (default) and both width and height are provided,
the image is resized to fit within these dimensions while
maintaining its aspect ratio. If False, the image is stretched
to the exact width and height.
resample_filter: The resampling filter to use for resizing.
Defaults to Image.BICUBIC.
min_width: Minimum width for the resized image. Defaults to 1.
min_height: Minimum height for the resized image. Defaults to 1.
Returns:
The resized PIL Image object. If no resize parameters are provided
or if the image already meets the criteria, the original image is returned.
Order of precedence for resizing:
1. If width AND height are provided:
- If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio.
- If keep_aspect_ratio is False: Resizes to exact dimensions (may distort).
2. Else if only width is provided: Calculates height proportionally.
3. Else if only height is provided: Calculates width proportionally.
4. Else if max_dimension is provided: Resizes the longest side to max_dimension
and scales the other side proportionally.
5. If none of the above are provided, returns the original image.
"""
original_width, original_height = image.size
# If no arguments are provided, return the original image
if width is None and height is None and max_dimension is None:
return image
new_width = original_width
new_height = original_height
if width is not None and height is not None:
# No resize needed if image is already smaller than target dimensions
if original_width <= width and original_height <= height:
return image
if keep_aspect_ratio:
# Calculate the ratio to fit within the target dimensions
ratio = min(width / original_width, height / original_height)
new_width = int(original_width * ratio)
new_height = int(original_height * ratio)
else:
# Stretch to exact dimensions
new_width = width
new_height = height
elif width is not None:
# No resize needed if width is already smaller
if original_width <= width:
return image
# Calculate height proportionally
new_width = width
new_height = int((original_height / original_width) * new_width)
elif height is not None:
# No resize needed if height is already smaller
if original_height <= height:
return image
# Calculate width proportionally
new_height = height
new_width = int((original_width / original_height) * new_height)
elif max_dimension is not None:
# No resize needed if both dimensions are smaller than max_dimension
if max(original_height, original_width) <= max_dimension:
return image
if original_width > original_height:
# Width is the longer side
new_width = max_dimension
new_height = int((original_height / original_width) * new_width)
else:
# Height is the longer side or sides are equal
new_height = max_dimension
new_width = int((original_width / original_height) * new_height)
# Ensure dimensions are at least minimum values
new_width = max(min_width, new_width)
new_height = max(min_height, new_height)
# Perform the resize operation with the calculated dimensions
return image.resize((new_width, new_height), resample_filter)
def truncate_tokens(
tokens: List[int],
max_length: int,
tokenizer: "PreTrainedTokenizerBase",
strategy: str = "left",
):
if strategy == "left":
return tokens[-max_length:]
elif strategy == "right":
return tokens[:max_length]
elif strategy == "middle":
# Truncate the middle of the sequence
left_length = max_length // 2
right_length = max_length - left_length
return tokens[:left_length] + tokens[-right_length:]
return None
 import copy
+import gc
+import inspect
 import logging
+import os
 from importlib.metadata import version
 from importlib.util import find_spec
+from multiprocessing import Process, Queue
+from queue import Empty
+from time import sleep
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 from more_itertools import distribute
@@ -28,6 +34,7 @@ try:
     from vllm import LLM, SamplingParams
     from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
+    from vllm.utils import get_open_port

     if parse_version(version("vllm")) >= parse_version("0.8.3"):
         from vllm.entrypoints.chat_utils import resolve_hf_chat_template
...@@ -40,6 +47,63 @@ if TYPE_CHECKING: ...@@ -40,6 +47,63 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
model_args: dict,
sampling_params: "SamplingParams",
requests: list[list[int]],
lora_request: "LoRARequest",
result_queue: "Queue",
dp_size: int,
local_dp_rank: int,
dp_master_port: int,
dp_master_ip: str = "127.0.0.1",
) -> None:
"""
Worker process for vLLM multiprocessing.
Initializes a vLLM engine, processes requests, and puts results or errors
onto the result_queue.
"""
if not requests:
result_queue.put((local_dp_rank, []))
return None
os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
llm = None
try:
llm = LLM(**model_args)
res = llm.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
lora_request=lora_request,
)
# Give engines time to pause their processing loops before exiting."
sleep(1)
result_queue.put((local_dp_rank, res))
except Exception as e:
error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
eval_logger.error(error_message, exc_info=True)
result_queue.put((local_dp_rank, {"error": error_message}))
finally:
if llm is not None:
try:
del llm
gc.collect()
except Exception as e_cleanup:
eval_logger.warning(
f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
exc_info=True,
)
return None
@register_model("vllm") @register_model("vllm")
class VLLM(TemplateLM): class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048 _DEFAULT_MAX_LENGTH = 2048
@@ -68,6 +132,7 @@ class VLLM(TemplateLM):
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: str = None,
+       enable_thinking: bool = False,
        **kwargs,
    ):
        super().__init__()
@@ -81,7 +146,7 @@ class VLLM(TemplateLM):
        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )
+       self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
@@ -96,9 +161,11 @@ class VLLM(TemplateLM):
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
+           "max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
+           "device": str(device),
        }
        self.model_args.update(kwargs)
        self.batch_size = (
@@ -112,7 +179,11 @@ class VLLM(TemplateLM):
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
-           self.model_args["distributed_executor_backend"] = "ray"
+           self.model_args["distributed_executor_backend"] = (
+               "ray"
+               if not self.V1
+               else self.model_args.get("distributed_executor_backend", None)
+           )
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")
@@ -129,6 +200,7 @@ class VLLM(TemplateLM):
            add_bos_token=add_bos_token,
        )
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
+       self.enable_thinking = enable_thinking
        self.add_bos_token = add_bos_token
        if "gemma" in pretrained.lower():
            self.add_bos_token = True
...@@ -137,11 +209,36 @@ class VLLM(TemplateLM): ...@@ -137,11 +209,36 @@ class VLLM(TemplateLM):
) )
if parse_version(version("vllm")) >= parse_version("0.8.3"): if parse_version(version("vllm")) >= parse_version("0.8.3"):
kwargs_resolve_hf_chat_template = {
"tokenizer": self.tokenizer,
"chat_template": None,
"tools": None,
}
if parse_version(version("vllm")) >= parse_version("0.9.0"):
if self.data_parallel_size <= 1:
kwargs_resolve_hf_chat_template["model_config"] = (
self.model.llm_engine.model_config
)
else:
from vllm.engine.arg_utils import EngineArgs
engine_args = EngineArgs(**self.model_args)
model_config = engine_args.create_model_config()
kwargs_resolve_hf_chat_template["model_config"] = model_config
# https://github.com/vllm-project/vllm/pull/18259
if (
"trsut_remote_code"
in inspect.signature(resolve_hf_chat_template).parameters
):
kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
else:
kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
-           self.hf_chat_template = resolve_hf_chat_template(
-               tokenizer=self.tokenizer,
-               chat_template=None,
-               tools=None,
-               trust_remote_code=trust_remote_code,
-           )
+           self.hf_chat_template = resolve_hf_chat_template(
+               **kwargs_resolve_hf_chat_template
+           )
        else:
            self.hf_chat_template = None
@@ -209,6 +306,7 @@ class VLLM(TemplateLM):
            add_generation_prompt=add_generation_prompt,
            continue_final_message=not add_generation_prompt,
            chat_template=self.hf_chat_template,
+           enable_thinking=self.enable_thinking,
        )

        return chat_templated
@@ -257,7 +355,7 @@ class VLLM(TemplateLM):
        sampling_params = SamplingParams(
            temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
        )
-       if self.data_parallel_size > 1:
+       if self.data_parallel_size > 1 and not self.V1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
...@@ -288,14 +386,83 @@ class VLLM(TemplateLM): ...@@ -288,14 +386,83 @@ class VLLM(TemplateLM):
ray.shutdown() ray.shutdown()
# flatten results # flatten results
return undistribute(results) return undistribute(results)
elif self.data_parallel_size > 1:
# based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
dp_size = self.data_parallel_size
dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
requests = (list(x) for x in distribute(self.data_parallel_size, requests))
procs, resq = [], Queue()
# We use Process as it is non-daemonic
try:
for rank, req in enumerate(requests):
proc = Process(
target=_vllm_mp_worker,
args=(
self.model_args.copy(),
sampling_params,
req,
self.lora_request,
resq,
dp_size,
rank,
dp_master_port,
dp_master_ip,
),
)
proc.start()
procs.append(proc)
# Collect results
rank_res = {}
while len(rank_res) < len(procs):
try:
rank, result = resq.get(timeout=30)
if isinstance(result, dict) and "error" in result:
raise RuntimeError(result["error"])
rank_res[rank] = result
except Empty:
dead_procs = [
idx
for idx, p in enumerate(procs)
if not p.is_alive() and idx not in rank_res
]
if dead_procs:
raise RuntimeError(
f"Worker processes {dead_procs} died unexpectedly"
)
continue
results = [rank_res[i] for i in range(len(procs))]
return undistribute(results)
# cleanup
finally:
try:
resq.close()
resq.join_thread()
except Exception:
eval_logger.debug(
"Failed to close vllm DP results queue", exc_info=True
)
for proc in procs:
proc.join(timeout=10)
if proc.is_alive():
proc.terminate()
proc.join(timeout=5)
if proc.is_alive():
proc.kill()
-       outputs = self.model.generate(
-           prompt_token_ids=requests,
-           sampling_params=sampling_params,
-           use_tqdm=True if self.batch_size == "auto" else False,
-           lora_request=self.lora_request,
-       )
-       return outputs
+       else:
+           outputs = self.model.generate(
+               prompt_token_ids=requests,
+               sampling_params=sampling_params,
+               use_tqdm=True if self.batch_size == "auto" else False,
+               lora_request=self.lora_request,
+           )
+           return outputs
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: List[Instance], disable_tqdm: bool = False
@@ -427,6 +594,12 @@ class VLLM(TemplateLM):
            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
+           all_lengths = [len(x) for x in context_encoding]
+           for length in all_lengths:
+               if length > max_ctx_len:
+                   eval_logger.warning(
+                       f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
+                   )
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
@@ -441,6 +614,10 @@ class VLLM(TemplateLM):
            # cache generations
            for output, context in zip(cont, context):
                generated_text = output.outputs[0].text
+               # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+               for term in until:
+                   if len(term) > 0:
+                       generated_text = generated_text.split(term)[0]
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), generated_text
@@ -477,6 +654,12 @@ class VLLM(TemplateLM):
            inputs = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
+               if (
+                   full_length := len(context_enc + continuation_enc)
+               ) > self.max_length:
+                   eval_logger.warning(
+                       f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
+                   )
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )