"tests/models/unets/test_models_unet_2d.py" did not exist on "4d1cce2fd01056515f0f353322a231164a4a5c5d"
Commit 5075de60 authored by baberabb

fix chunking of inputs

parent 9c0d4e93
 from collections import defaultdict
-from typing import List, Tuple, Optional, Literal, Union
+from itertools import islice
+from typing import List, Tuple, Optional, Literal, Union, Any

 from transformers import AutoTokenizer

 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
@@ -24,6 +25,11 @@ def run_inference_one_gpu(model_args: dict, sampling_params, requests: List[int]
     return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)


+def chunk_list(my_list: List[Any], chunk_size: int):
+    for i in range(0, len(my_list), chunk_size):
+        yield list(islice(my_list, i, i + chunk_size))
+
+
 @register_model("vllm")
 class VLLM(LM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -137,16 +143,13 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             temperature=0, prompt_logprobs=2, max_tokens=1
         )
         if self.data_parallel > 1:
-            req_list = []
-            for replicas in range(self.data_parallel):
-                reqs = utils.create_iterator(
-                    requests, rank=replicas, world_size=self.data_parallel
-                )
-                req_list.append(reqs)
-            inputs = [(self.model_args, sampling_params, req) for req in req_list]
-            with Pool(processes=self.data_parallel) as pool:
-                results = pool.starmap(run_inference_one_gpu, inputs)
+            requests = chunk_list(requests, self.data_parallel)
+            inputs = [(self.model_args, sampling_params, req) for req in requests]
+            with Pool() as pool:
+                results = pool.starmap(
+                    run_inference_one_gpu, inputs, self.data_parallel
+                )
             # flatten results
             return [item for sublist in results for item in sublist]
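
For context, a rough sketch of what the new chunk_list helper does at the call site (the request values below are made up; chunk_list itself is copied from the diff). The old utils.create_iterator split requests round-robin by rank, so flattening the per-worker results would not restore the original request order; contiguous chunks do flatten back in order, which is presumably what "fix chunking of inputs" refers to. Note also that chunk_list(requests, self.data_parallel) yields chunks of size data_parallel, i.e. roughly len(requests) / data_parallel chunks, not data_parallel chunks.

    from itertools import islice
    from typing import Any, List

    def chunk_list(my_list: List[Any], chunk_size: int):
        # Yield successive chunk_size-long slices of my_list as concrete lists.
        for i in range(0, len(my_list), chunk_size):
            yield list(islice(my_list, i, i + chunk_size))

    data_parallel = 2
    requests = [[11, 12], [21, 22], [31, 32], [41, 42], [51]]  # made-up token-id prompts

    for chunk in chunk_list(requests, data_parallel):
        print(chunk)
    # [[11, 12], [21, 22]]
    # [[31, 32], [41, 42]]
    # [[51]]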
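
The dispatch itself is standard multiprocessing; below is a self-contained sketch with a trivial run_echo standing in for run_inference_one_gpu (everything here is illustrative, not the harness's actual code). One detail worth flagging: the third positional argument to Pool.starmap is the chunksize, i.e. how many argument tuples each worker pulls from the task queue at a time, not the number of workers.

    from multiprocessing import Pool
    from typing import Any, List

    def run_echo(model_args: dict, sampling_params: Any, requests: List[Any]) -> List[str]:
        # Stand-in for run_inference_one_gpu: one fake result per request.
        return [f"result-for-{req}" for req in requests]

    if __name__ == "__main__":
        data_parallel = 2
        chunks = [[0, 1], [2, 3], [4]]  # what chunk_list would yield in the real code
        inputs = [({}, None, chunk) for chunk in chunks]
        with Pool() as pool:
            results = pool.starmap(run_echo, inputs, data_parallel)
        # flatten results, as in the diff
        print([item for sublist in results for item in sublist])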