Commit da2d160b authored by baberabb's avatar baberabb
Browse files

fixed chunking

parent b97c561f
...@@ -7,7 +7,6 @@ import copy ...@@ -7,7 +7,6 @@ import copy
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval import utils from lm_eval import utils
import numpy as np
try: try:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
...@@ -142,7 +141,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -142,7 +141,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
temperature=0, prompt_logprobs=2, max_tokens=1 temperature=0, prompt_logprobs=2, max_tokens=1
) )
if self.data_parallel > 1: if self.data_parallel > 1:
requests = np.array_split(requests, self.data_parallel) requests = [list(x) for x in (utils.divide(requests, self.data_parallel))]
inputs = [(self.model_args, sampling_params, req) for req in requests] inputs = [(self.model_args, sampling_params, req) for req in requests]
with Pool(self.data_parallel) as pool: with Pool(self.data_parallel) as pool:
......
...@@ -662,3 +662,55 @@ def stop_sequences_criteria( ...@@ -662,3 +662,55 @@ def stop_sequences_criteria(
], ],
] ]
) )
# from more_itertools
def divide(iterable, n) -> List[Iterator]:
"""Divide the elements from *iterable* into *n* parts, maintaining
order.
>>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
>>> list(group_1)
[1, 2, 3]
>>> list(group_2)
[4, 5, 6]
If the length of *iterable* is not evenly divisible by *n*, then the
length of the returned iterables will not be identical:
>>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
>>> [list(c) for c in children]
[[1, 2, 3], [4, 5], [6, 7]]
If the length of the iterable is smaller than n, then the last returned
iterables will be empty:
>>> children = divide(5, [1, 2, 3])
>>> [list(c) for c in children]
[[1], [2], [3], [], []]
This function will exhaust the iterable before returning and may require
significant storage. If order is not important, see :func:`distribute`,
which does not first pull the iterable into memory.
"""
if n < 1:
raise ValueError("n must be at least 1")
try:
iterable[:0]
except TypeError:
seq = tuple(iterable)
else:
seq = iterable
q, r = divmod(len(seq), n)
ret = []
stop = 0
for i in range(1, n + 1):
start = stop
stop += q + 1 if i <= r else q
ret.append(iter(seq[start:stop]))
return ret
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment