Unverified Commit a3d08277 authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #1011 from baberabb/big-refactor_vllm

[Refactor] vllm support
parents 30936bc7 4277dc35
@@ -39,7 +39,7 @@ pip install -e ".[gptq]"
```
Though we recommend only installing the extras you require, to install the package with all extras, run
```bash
pip install -e ".[all]"
```
@@ -132,23 +132,44 @@ To use `accelerate` with the `lm-eval` command, use
accelerate launch --no_python lm-eval --model ...
```
### Tensor Parallel + Optimized Inference with vLLM
We also support vLLM for faster inference on [supported model types](https://docs.vllm.ai/en/latest/models/supported_models.html).
To run with vLLM, first install the `vllm` package, either directly or via the `lm_eval[vllm]` extra:
```bash
pip install -e .[vllm]
```
Then, you can run the library as normal, for single-GPU or tensor-parallel inference, for example:
```bash
python -m lm_eval \
--model vllm \
--model_args pretrained={model_name},tensor_parallel_size={number of GPUs to use},dtype=auto,gpu_memory_utilization=0.8 \
--tasks lambada_openai \
--batch_size auto
```
For a full list of supported vLLM configurations, please reference our vLLM integration and the vLLM documentation.
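vLLM evaluations can also be launched from Python. The following is a minimal sketch, assuming the `simple_evaluate` entry point in `lm_eval.evaluator`; the checkpoint name and memory settings are placeholders to adapt to your setup:
```python
from lm_eval import evaluator, tasks

tasks.initialize_tasks()  # register the available tasks

# Sketch only: model_args mirrors the CLI flags shown above.
results = evaluator.simple_evaluate(
    model="vllm",
    model_args="pretrained=EleutherAI/pythia-1.4b,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8",
    tasks=["lambada_openai"],
    batch_size="auto",
)
print(results["results"])
```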
### Supported APIs and Inference Libraries
Our library also supports the evaluation of models served via several commercial APIs, and we hope to implement support for the most commonly used performant local/self-hosted inference servers.
A full accounting of the supported and planned libraries + APIs can be seen below:
| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|-----------------------------|---------------------------------|----------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------|
| OpenAI Completions | :heavy_check_mark: | `openai`, `openai-completions`, `gooseai` | up to `code-davinci-002` | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| OpenAI ChatCompletions | :x: Not yet - needs testing! | N/A | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
| GooseAI | :heavy_check_mark: (not separately maintained) | `openai`, `openai-completions`, `gooseai` (same interface as OpenAI Completions) | | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Textsynth | Needs testing | `textsynth` | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Cohere | :hourglass: - blocked on Cohere API bug | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| GGML/[Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | Llama-architecture models (Llama, Llama 2, Llemma, Mistral(?), Llama finetunes) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Your inference server here! | ... | ... | ... | ... |
It is on our roadmap to create task variants designed so that models which do not serve logprobs/loglikelihoods can be compared against the generation performance of open-source models.
...
@@ -102,6 +102,8 @@ class MyCustomLM(LM):
Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
**Tip: be sure to import your model in `lm_eval/models/__init__.py`!**
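For reference, here is a minimal sketch of what a registered model stub can look like. The class name matches the docs above, while the `"my-backend"` registry name and file name are only illustrative:
```python
# lm_eval/models/my_backend.py (hypothetical file name)
from typing import List, Tuple

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


@register_model("my-backend")
class MyCustomLM(LM):
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        ...

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        ...

    def generate_until(self, requests: List[Instance]) -> List[str]:
        ...


# and, per the tip above, in lm_eval/models/__init__.py:
# from . import my_backend
```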
## Testing
We also recommend that new model contributions be accompanied by short tests of their 3 core functionalities, at minimum. To see an example of such tests, look at https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py .
...
@@ -4,6 +4,6 @@ from . import textsynth
from . import dummy
from . import anthropic_llms
from . import gguf
from . import vllm_causallms
# TODO: implement __all__
@@ -16,13 +16,14 @@ from pathlib import Path
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator, find_executable_batch_size, DistributedType
from typing import List, Optional, Union, Tuple
eval_logger = utils.eval_logger
@@ -420,7 +421,9 @@ class HFLM(LM):
utils.clear_torch_cache()
return batch_size
def tok_encode(
    self, string: str, left_truncate_len=None, add_special_tokens=None
) -> List[int]:
""" """
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -442,7 +445,7 @@ class HFLM(LM):
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
) -> Tuple[List[int], List[int]]:
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
@@ -536,7 +539,9 @@ class HFLM(LM):
return logits
def _encode_pair(
    self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
@@ -551,7 +556,7 @@ class HFLM(LM):
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
@@ -566,7 +571,7 @@ class HFLM(LM):
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
loglikelihoods = []
adaptive_batch_size = None
@@ -640,8 +645,11 @@ class HFLM(LM):
return self.batch_sizes[sched]
def _loglikelihood_tokens(
    self,
    requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
    disable_tqdm: bool = False,
    override_bs: int = None,
) -> List[Tuple[float, bool]]:
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
@@ -820,7 +828,7 @@ class HFLM(LM):
return re_ord.get_original(res)
def generate_until(self, requests: List[Instance]) -> List[str]:
res = defaultdict(list)
re_ords = {}
...
from collections import defaultdict
from typing import List, Tuple, Optional, Literal, Union
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
import copy
from tqdm import tqdm
from lm_eval.api.registry import register_model
from lm_eval import utils
try:
from vllm import LLM, SamplingParams
except ModuleNotFoundError:
pass
eval_logger = utils.eval_logger
@register_model("vllm")
class VLLM(LM):
_DEFAULT_MAX_LENGTH = 2048
def __init__(
self,
pretrained="gpt2",
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tensor_parallel_size: int = 1,
quantization: Optional[Literal["awq"]] = None,
max_gen_toks: int = 256,
swap_space: int = 4,
batch_size: Union[str, int] = 1,
max_batch_size=None,
max_length: int = None,
seed: int = 1234,
gpu_memory_utilization: float = 0.9,
device: str = "cuda",
):
super().__init__()
try:
import vllm
except ModuleNotFoundError:
raise Exception(
"attempted to use 'vllm' LM type, but package `vllm` is not installed. \
please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`",
)
assert "cuda" in device or device is None, "vLLM only supports CUDA"
self.model = LLM(
model=pretrained,
gpu_memory_utilization=float(gpu_memory_utilization),
revision=revision,
dtype=dtype,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=int(tensor_parallel_size),
swap_space=int(swap_space),
quantization=quantization,
seed=int(seed),
)
self.tokenizer = self.model.get_tokenizer()
self.batch_size = batch_size
self._max_length = max_length
self._max_gen_toks = max_gen_toks
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
if self._max_length: # if max length manually set, return it
return self._max_length
if hasattr(self.model.llm_engine.model_config, "max_model_len"):
return self.model.llm_engine.model_config.max_model_len
return self._DEFAULT_MAX_LENGTH
@property
def max_gen_toks(self):
return self._max_gen_toks
def tok_encode(
self,
string: str,
left_truncate_len=None,
add_special_tokens=False,
truncation=False,
):
""" """
encoding = self.tokenizer.encode(
string, add_special_tokens=add_special_tokens, truncation=truncation
)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
def _model_generate(
self,
requests: List[List[int]] = None,
generate: bool = False,
max_tokens: int = None,
stop: Optional[List[str]] = None,
use_tqdm=True,
**kwargs,
):
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if generate:
generate_sampling_params = SamplingParams(
max_tokens=max_tokens, stop=stop, **kwargs
)
outputs = self.model.generate(
prompt_token_ids=requests,
sampling_params=generate_sampling_params,
use_tqdm=use_tqdm,
)
else:
loglikelihood_sampling_params = SamplingParams(
temperature=0, prompt_logprobs=2, max_tokens=1
)
outputs = self.model.generate(
prompt_token_ids=requests,
sampling_params=loglikelihood_sampling_params,
use_tqdm=use_tqdm,
)
return outputs
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
continuation
)
else:
context_enc, continuation_enc = self.tokenizer(
[context, continuation],
truncation="do_not_truncate",
add_special_tokens=False,
return_attention_mask=False,
).input_ids
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length - 1,
context_len=1,
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def generate_until(self, requests: List[Instance]) -> List[str]:
res = defaultdict(list)
re_ords = {}
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding = self.tokenizer(context).input_ids
requests = [
((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
]
def _collate_gen(_requests):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
return -len(_requests[0][1]), tuple(_requests[0][1])
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer(reqs, _collate_gen)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size if self.batch_size != "auto" else 0,
fn=None,
)
for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tokenizer.decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# TODO: max_length in kwargs
# perform batched generation
cont = self._model_generate(
requests=context_encoding,
generate=True,
max_tokens=max_gen_toks,
stop=until,
**kwargs,
)
# cache generations
for output, context in zip(cont, context):
generated_text = output.outputs[0].text
res[key].append(generated_text)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close()
return grouper.get_original(res)
def _loglikelihood_tokens(
self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
disable_tqdm: bool = False,
) -> List[Tuple[float, bool]]:
res = []
def _collate(x):
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size if self.batch_size != "auto" else 0,
fn=None,
)
pbar = tqdm(total=len(requests), disable=disable_tqdm)
for chunk in chunks:
inps = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
inp = (context_enc + continuation_enc)[-(self.max_length) :]
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length)
)
inps.append(inp)
ctxlens.append(ctxlen)
outputs = self._model_generate(requests=inps, generate=False)
for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
outputs, ctxlens, chunk
):
answer = self._parse_logprobs(
(context_enc + continuation_enc),
output,
ctxlen,
)
res.append(answer)
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
pbar.update(1)
pbar.close()
return re_ord.get_original(res)
@staticmethod
def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
"""Process logprobs and tokens.
:param tokens: list
Tokens from context+continuations
:param outputs: RequestOutput
Contains prompt
:param ctxlen: int
Length of context (so we can slice them away and only keep the predictions)
:return:
continuation_logprobs: float
Log probabilities of continuation tokens
is_greedy: bool
Whether argmax matches given continuation exactly
"""
# prompt_logprobs = [None, {}*len(context-1)]
continuation_logprobs_dicts = outputs.prompt_logprobs
# Calculate continuation_logprobs
# assume ctxlen always > 1
continuation_logprobs = sum(
logprob_dict.get(token)
for token, logprob_dict in zip(
tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
)
)
# Determine if is_greedy
is_greedy = True
for token, logprob_dict in zip(
tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
):
# Get the token with the maximum log probability from the logprob_dict
if logprob_dict: # Ensure the logprob_dict is not None
top_token = max(logprob_dict, key=logprob_dict.get)
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
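As a concrete illustration of the parsing above (not part of the PR itself): `prompt_logprobs` is a per-position list of `{token_id: logprob}` dicts aligned with the prompt tokens, so the continuation score is the sum of the logprobs of the actual continuation tokens, and `is_greedy` checks whether each continuation token is also the argmax at its position. The payload below is invented for the example:
```python
from types import SimpleNamespace

from lm_eval.models.vllm_causallms import VLLM

tokens = [10, 11, 12, 13]  # 2 context tokens + 2 continuation tokens
fake_output = SimpleNamespace(
    prompt_logprobs=[
        None,                   # vLLM emits no logprobs for the first prompt token
        {11: -0.4, 99: -2.0},
        {12: -0.5, 99: -1.5},   # continuation token 12 is also the argmax here
        {13: -0.25, 99: -4.0},  # continuation token 13 is also the argmax here
    ]
)
logprob_sum, is_greedy = VLLM._parse_logprobs(tokens, fake_output, ctxlen=2)
print(logprob_sum, is_greedy)  # -0.75 True
```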
@@ -10,7 +10,7 @@ import collections
import importlib.util
import fnmatch
from typing import Iterator, List, Literal, Union, Any, Callable
import gc
import torch
@@ -84,6 +84,32 @@ def join_iters(iters):
def chunks(iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = []
for i, x in enumerate(iter):
arr.append(x)
@@ -194,7 +220,13 @@ def make_disjoint_window(pair):
class Reorderer:
def __init__(self, arr: List[Any], fn: Callable) -> None:
"""Reorder an array according to some function
Args:
arr (List[Any]): The initial array
fn (Callable[[Any], Any]): A function to determine the priority of elements
"""
self.size = len(arr)
arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1]))
@@ -206,9 +238,22 @@ class Reorderer:
self.arr = arr
def get_reordered(self):
"""Gets the reordered array
Returns:
List[Any]: The reordered array
"""
return [x[1] for x in self.arr]
def get_original(self, newarr):
"""Restores the original order of a new array based on the old array's order
Args:
newarr (List[Any]): The array to be restored
Returns:
List[Any]: The array restored to the original order
"""
res = [None] * self.size
cov = [False] * self.size
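A minimal round-trip sketch of the documented behaviour (assuming this version of `lm_eval.utils`): elements are handed out sorted by the priority function, and `get_original` restores the caller's ordering afterwards.
```python
from lm_eval.utils import Reorderer

data = ["bb", "a", "cccc"]
ro = Reorderer(data, lambda s: len(s))       # priority: string length, ascending
reordered = ro.get_reordered()               # ["a", "bb", "cccc"]
processed = [s.upper() for s in reordered]   # stand-in for the expensive batched work
print(ro.get_original(processed))            # ["BB", "A", "CCCC"] -- original order restored
```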
@@ -435,7 +480,6 @@ yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
@@ -456,7 +500,6 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
...
@@ -71,6 +71,7 @@ promptsource = [
gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
anthropic = ["anthropic"]
openai = ["openai", "tiktoken"]
vllm = ["vllm"]
all = [
"lm_eval[dev]",
"lm_eval[testing]",
@@ -80,5 +81,6 @@ all = [
"lm_eval[promptsource]",
"lm_eval[gptq]",
"lm_eval[anthropic]",
"lm_eval[openai]",
"lm_eval[vllm]",
]
import pytest
from typing import List
from lm_eval.api.instance import Instance
import lm_eval.tasks as tasks
import sys
import torch
@pytest.mark.skip(reason="requires CUDA")
class Test_VLLM:
vllm = pytest.importorskip("vllm")
try:
from lm_eval.models.vllm_causallms import VLLM
LM = VLLM(pretrained="EleutherAI/pythia-70m")
except ModuleNotFoundError:
pass
torch.use_deterministic_algorithms(True)
tasks.initialize_tasks()
multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")() # type: ignore
multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
MULTIPLE_CH: List[Instance] = multiple_choice_task.instances
generate_until_task = tasks.TASK_REGISTRY.get("gsm8k")() # type: ignore
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until: List[Instance] = generate_until_task.instances
rolling_task = tasks.TASK_REGISTRY.get("wikitext")() # type: ignore
rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
ROLLING: List[Instance] = rolling_task.instances
# TODO: make proper tests
def test_loglikelihood(self) -> None:
res = self.LM.loglikelihood(self.MULTIPLE_CH)
assert len(res) == len(self.MULTIPLE_CH)
for x in res:
assert isinstance(x[0], float)
def test_generate_until(self) -> None:
res = self.LM.generate_until(self.generate_until)
assert len(res) == len(self.generate_until)
for x in res:
assert isinstance(x, str)
def test_loglikelihood_rolling(self) -> None:
res = self.LM.loglikelihood_rolling(self.ROLLING)
for x in res:
assert isinstance(x, float)