Unverified Commit 27a3da96 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Improve data-parallel request partitioning for VLLM (#1477)

* add undistribute + use more_itertools

* remove divide() util fn

* add more_itertools as dependency
parent 284dd80d
import collections
import fnmatch
import gc
import itertools
import time
from functools import wraps
from typing import (
......@@ -262,55 +263,44 @@ def stop_sequences_criteria(
)
def undistribute(iterable):
    """Re-interleave groups produced by ``more_itertools.distribute``.

    ``distribute(n, seq)`` deals *seq* round-robin into *n* groups; this
    function undoes that, restoring the original order:

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]
    >>> undistribute([group_1, group_2])
    [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]
    >>> undistribute(children)
    [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]
    >>> undistribute(children)
    [1, 2, 3]

    :param iterable: an iterable of iterables, as returned by
        ``more_itertools.distribute``. Each component is fully consumed.
    :return: a single flat list with the original element order restored.
    """
    # Use a private sentinel (not None) as zip_longest's fillvalue so that
    # legitimate None elements in the data are preserved rather than dropped.
    _pad = object()
    return [
        x
        for x in itertools.chain.from_iterable(
            # zip_longest re-interleaves column-wise; shorter groups are
            # padded with the sentinel, which is filtered back out below.
            itertools.zip_longest(*[list(group) for group in iterable], fillvalue=_pad)
        )
        if x is not _pad
    ]
def retry_on_specific_exceptions(
......
......@@ -2,12 +2,13 @@ import copy
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple, Union
from more_itertools import distribute
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, divide
from lm_eval.models.utils import Collator, undistribute
from lm_eval.utils import (
eval_logger,
get_rolling_token_windows,
......@@ -181,7 +182,9 @@ class VLLM(TemplateLM):
temperature=0, prompt_logprobs=1, max_tokens=1
)
if self.data_parallel_size > 1:
requests = [list(x) for x in divide(requests, self.data_parallel_size)]
# dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
inputs = [(self.model_args, sampling_params, req) for req in requests]
with Pool(self.data_parallel_size) as pool:
......@@ -189,7 +192,7 @@ class VLLM(TemplateLM):
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
ray.shutdown()
# flatten results
return [item for sublist in results for item in sublist]
return undistribute(results)
outputs = self.model.generate(
prompt_token_ids=requests,
......
......@@ -38,6 +38,7 @@ dependencies = [
"zstandard",
"dill",
"word2number",
"more_itertools",
]
[tool.setuptools.packages.find]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment