Unverified Commit 27a3da96 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Improve data-parallel request partitioning for VLLM (#1477)

* add undistribute + use more_itertools

* remove divide() util fn

* add more_itertools as dependency
parent 284dd80d
import collections import collections
import fnmatch import fnmatch
import gc import gc
import itertools
import time import time
from functools import wraps from functools import wraps
from typing import ( from typing import (
...@@ -262,55 +263,44 @@ def stop_sequences_criteria( ...@@ -262,55 +263,44 @@ def stop_sequences_criteria(
) )
def undistribute(iterable):
    """Undo ``more_itertools.distribute``.

    Re-interleaves sub-iterables that were split apart by
    ``more_itertools.distribute``, restoring the original element order:

    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> undistribute([group_1, group_2])
    [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]
    >>> undistribute(children)
    [1, 2, 3, 4, 5, 6, 7]

    Also handles when some of the sub-iterables are empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> undistribute(children)
    [1, 2, 3]

    :param iterable: an iterable of iterables, as returned by ``distribute``.
    :return: a flat list with the original element order restored.
    """
    # Use a unique sentinel object (not None) as the zip_longest fill value
    # so that legitimate None elements in the data are not silently dropped.
    fill = object()
    interleaved = itertools.zip_longest(
        *(list(group) for group in iterable), fillvalue=fill
    )
    return [
        x for x in itertools.chain.from_iterable(interleaved) if x is not fill
    ]
def retry_on_specific_exceptions( def retry_on_specific_exceptions(
......
...@@ -2,12 +2,13 @@ import copy ...@@ -2,12 +2,13 @@ import copy
from importlib.util import find_spec from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple, Union from typing import List, Literal, Optional, Tuple, Union
from more_itertools import distribute
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, divide from lm_eval.models.utils import Collator, undistribute
from lm_eval.utils import ( from lm_eval.utils import (
eval_logger, eval_logger,
get_rolling_token_windows, get_rolling_token_windows,
...@@ -181,7 +182,9 @@ class VLLM(TemplateLM): ...@@ -181,7 +182,9 @@ class VLLM(TemplateLM):
temperature=0, prompt_logprobs=1, max_tokens=1 temperature=0, prompt_logprobs=1, max_tokens=1
) )
if self.data_parallel_size > 1: if self.data_parallel_size > 1:
requests = [list(x) for x in divide(requests, self.data_parallel_size)] # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
inputs = [(self.model_args, sampling_params, req) for req in requests] inputs = [(self.model_args, sampling_params, req) for req in requests]
with Pool(self.data_parallel_size) as pool: with Pool(self.data_parallel_size) as pool:
...@@ -189,7 +192,7 @@ class VLLM(TemplateLM): ...@@ -189,7 +192,7 @@ class VLLM(TemplateLM):
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required. # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
ray.shutdown() ray.shutdown()
# flatten results # flatten results
return [item for sublist in results for item in sublist] return undistribute(results)
outputs = self.model.generate( outputs = self.model.generate(
prompt_token_ids=requests, prompt_token_ids=requests,
......
...@@ -38,6 +38,7 @@ dependencies = [ ...@@ -38,6 +38,7 @@ dependencies = [
"zstandard", "zstandard",
"dill", "dill",
"word2number", "word2number",
"more_itertools",
] ]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment