Commit b97c561f authored by baberabb

simpler request chunking

parent c57ca81a
@@ -7,6 +7,7 @@ import copy
 from tqdm import tqdm
 from lm_eval.api.registry import register_model
 from lm_eval import utils
+import numpy as np
 try:
     from vllm import LLM, SamplingParams
@@ -18,7 +19,7 @@ except ModuleNotFoundError:
 eval_logger = utils.eval_logger
-# adapter from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
+# adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
 def run_inference_one_model(model_args: dict, sampling_params, requests: List[int]):
     # gpu_id = [x for x in gpu_id]
     # os.environ["CUDA_VISIBLE_DEVICES"]= str(gpu_id)
@@ -26,11 +27,6 @@ def run_inference_one_model(model_args: dict, sampling_params, requests: List[in
     return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
-def chunk_list(lst, n):
-    chunk_size = len(lst) // n + (1 if len(lst) % n else 0)
-    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
 @register_model("vllm")
 class VLLM(LM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -130,7 +126,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
     def _model_generate(
         self,
-        requests: List[int] = None,
+        requests: List[List[int]] = None,
         generate: bool = False,
         max_tokens: int = None,
         stop: Optional[List[str]] = None,
@@ -146,7 +142,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             temperature=0, prompt_logprobs=2, max_tokens=1
         )
         if self.data_parallel > 1:
-            requests = chunk_list(requests, self.data_parallel)
+            requests = np.array_split(requests, self.data_parallel)
             inputs = [(self.model_args, sampling_params, req) for req in requests]
             with Pool(self.data_parallel) as pool:
...
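
The change drops the hand-rolled chunk_list helper in favour of np.array_split, which always returns exactly data_parallel chunks of near-equal size (chunk_list could produce fewer chunks than workers, e.g. 4 chunks for 7 requests split 5 ways). A minimal sketch of the new dispatch path follows; the worker body and the model_args/sampling_params values are placeholders, not the harness's real objects.

import numpy as np
from multiprocessing import Pool


def run_inference_one_model(model_args, sampling_params, requests):
    # Dummy worker body; the real worker builds an LLM from model_args and calls
    # llm.generate(prompt_token_ids=requests, sampling_params=sampling_params).
    return [len(req) for req in requests]


if __name__ == "__main__":
    data_parallel = 3
    model_args = {"model": "placeholder"}  # assumed placeholder, not the real model args
    sampling_params = None                 # assumed placeholder

    # Token-id lists, matching the updated List[List[int]] signature.
    requests = [[i, i + 1] for i in range(7)]

    # np.array_split always yields exactly `data_parallel` chunks of near-equal size.
    chunks = np.array_split(requests, data_parallel)
    inputs = [(model_args, sampling_params, chunk.tolist()) for chunk in chunks]

    with Pool(data_parallel) as pool:
        results = pool.starmap(run_inference_one_model, inputs)
    print([len(r) for r in results])  # chunk sizes: [3, 2, 2]

Note that np.array_split returns NumPy arrays rather than plain lists, hence the .tolist() conversion before handing each chunk to the worker in this sketch.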