"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0272f393c4415f45399b453de7179158b32f5e5d"
Commit 694dc642 authored by baberabb

fix data + tensor parallel

parent 5075de60
 from collections import defaultdict
+import os
 from itertools import islice
 from typing import List, Tuple, Optional, Literal, Union, Any
 from transformers import AutoTokenizer
@@ -9,6 +10,7 @@ from tqdm import tqdm
 from lm_eval.api.registry import register_model
 from lm_eval import utils
 from ray.util.multiprocessing import Pool
+import multiprocessing
 try:
@@ -21,13 +23,15 @@ eval_logger = utils.eval_logger
 def run_inference_one_gpu(model_args: dict, sampling_params, requests: List[int]):
+    # gpu_id = [x for x in gpu_id]
+    # os.environ["CUDA_VISIBLE_DEVICES"]= str(gpu_id)
     llm = LLM(**model_args)
     return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
-def chunk_list(my_list: List[Any], chunk_size: int):
-    for i in range(0, len(my_list), chunk_size):
-        yield list(islice(my_list, i, i + chunk_size))
+def chunk_list(lst, n):
+    chunk_size = len(lst) // n + (1 if len(lst) % n else 0)
+    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
 @register_model("vllm")
@@ -80,6 +84,8 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
         }
         if self.data_parallel <= 1:
             self.model = LLM(**self.model_args)
+        else:
+            self.model_args["worker_use_ray"] = True
         self.tokenizer = AutoTokenizer.from_pretrained(
             pretrained,
             revision=revision,
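In the else branch, worker_use_ray is forced on before any per-worker engine is built; in vLLM of this era that flag told each engine to launch its tensor-parallel workers through Ray rather than managing them itself, which is what lets several engines (one per data-parallel replica) coexist. A hedged sketch of the resulting arguments; the model name and tensor_parallel_size are hypothetical, only worker_use_ray comes from the commit:

# hypothetical model_args for one data-parallel replica
model_args = {
    "model": "facebook/opt-125m",   # hypothetical checkpoint
    "tensor_parallel_size": 2,      # hypothetical: GPUs per replica
    "worker_use_ray": True,         # set by the commit when data_parallel > 1
}
# each pool worker then constructs its own engine:
# llm = LLM(**model_args)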
@@ -146,10 +152,8 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             requests = chunk_list(requests, self.data_parallel)
             inputs = [(self.model_args, sampling_params, req) for req in requests]
-            with Pool() as pool:
-                results = pool.starmap(
-                    run_inference_one_gpu, inputs, self.data_parallel
-                )
+            with Pool(self.data_parallel) as pool:
+                results = pool.starmap(run_inference_one_gpu, inputs)
             # flatten results
             return [item for sublist in results for item in sublist]
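The final hunk also fixes a real bug: the old code passed self.data_parallel as the third positional argument of pool.starmap, which is chunksize (how many tasks are batched per dispatch), not the worker count. The fix sizes the Pool itself instead. Putting the hunks together, a consolidated sketch assuming run_inference_one_gpu and chunk_list as defined above (generate_data_parallel is a hypothetical name, not from the commit):

from ray.util.multiprocessing import Pool

def generate_data_parallel(model_args, sampling_params, requests, data_parallel):
    # fan out: one near-equal chunk of prompts per pool worker
    chunks = chunk_list(requests, data_parallel)
    inputs = [(model_args, sampling_params, chunk) for chunk in chunks]
    # pool size matches the chunk count, so each chunk gets its own
    # worker process and therefore its own LLM engine
    with Pool(data_parallel) as pool:
        results = pool.starmap(run_inference_one_gpu, inputs)
    # fan in: flatten per-worker result lists, preserving request order
    return [out for sublist in results for out in sublist]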