Unverified Commit 53fe3904 authored by Yang Yong, committed by GitHub

fix LightllmApi workers bug (#1113)

parent baed2ed9
...
@@ -6,7 +6,7 @@ from opencompass.tasks import OpenICLInferTask

 with read_base():
     from .summarizers.leaderboard import summarizer
-    from .datasets.humaneval.humaneval_gen import humaneval_datasets
+    from .datasets.humaneval.humaneval_gen_a82cae import humaneval_datasets

 datasets = [*humaneval_datasets]
@@ -32,7 +32,8 @@ models = [
         url='http://localhost:1030/generate',
         meta_template=_meta_template,
         batch_size=32,
-        rate_per_worker=32,
+        max_workers_per_task=128,
+        rate_per_worker=1024,
         retry=4,
         generation_kwargs=dict(
             do_sample=False,
...
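For reference, both new config keys feed straight into the LightllmAPI constructor patched below: max_workers_per_task caps the thread pool that fans out requests, while rate_per_worker still sets the token-bucket rate. A minimal direct instantiation, as a sketch only (the import path and argument values are illustrative, not part of this commit, and a lightllm server is assumed to be listening at the given url):

    # Hypothetical usage sketch; assumes LightllmAPI is exported from
    # opencompass.models and a lightllm server is running at `url`.
    from opencompass.models import LightllmAPI

    model = LightllmAPI(
        url='http://localhost:1030/generate',
        max_workers_per_task=128,  # new: caps threads per inference task
        rate_per_worker=1024,      # token-bucket rate pacing the requests
        retry=4,
        generation_kwargs=dict(do_sample=False, max_new_tokens=1024),
    )
    print(model.generate(['def add(a, b):'], max_out_len=1024))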
...
@@ -23,6 +23,7 @@ class LightllmAPI(BaseModel):
             path: str = 'LightllmAPI',
             url: str = 'http://localhost:8080/generate',
             meta_template: Optional[Dict] = None,
+            max_workers_per_task: int = 2,
             rate_per_worker: int = 2,
             retry: int = 2,
             generation_kwargs: Optional[Dict] = dict(),
@@ -37,6 +38,7 @@ class LightllmAPI(BaseModel):
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
         self.meta_template = meta_template
+        self.max_workers_per_task = max_workers_per_task
         self.token_bucket = TokenBucket(rate_per_worker, False)

     def generate(self, inputs: List[str], max_out_len: int,
@@ -53,7 +55,8 @@ class LightllmAPI(BaseModel):
             List[str]: A list of generated strings.
         """
-        with ThreadPoolExecutor() as executor:
+        with ThreadPoolExecutor(
+                max_workers=self.max_workers_per_task) as executor:
             results = list(
                 executor.map(self._generate, inputs,
                              [self.max_out_len] * len(inputs)))
@@ -103,7 +106,8 @@ class LightllmAPI(BaseModel):
             List[str]: A list of generated strings.
         """
-        with ThreadPoolExecutor() as executor:
+        with ThreadPoolExecutor(
+                max_workers=self.max_workers_per_task) as executor:
             results = list(
                 executor.map(self._get_ppl, inputs,
                              [self.max_out_len] * len(inputs)))
...
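The behavioral change is easiest to see in isolation: ThreadPoolExecutor() with no arguments falls back to Python's default worker count (min(32, os.cpu_count() + 4) on Python 3.8+), so the pool size was previously uncontrollable and the token bucket was the only throttle. This commit makes the pool size an explicit knob. Below is a self-contained sketch of the resulting pattern, with a simplified stand-in for opencompass's TokenBucket and a fake request function in place of the HTTP call:

    import threading
    import time
    from concurrent.futures import ThreadPoolExecutor


    class TokenBucket:
        """Simplified stand-in for opencompass's TokenBucket limiter."""

        def __init__(self, rate: float):
            self._rate = rate
            self._tokens = 0.0
            self._last = time.monotonic()
            self._lock = threading.Lock()

        def get_token(self) -> None:
            # Block until one token is available; refill at `rate`/second.
            while True:
                with self._lock:
                    now = time.monotonic()
                    self._tokens = min(
                        self._rate,
                        self._tokens + (now - self._last) * self._rate)
                    self._last = now
                    if self._tokens >= 1:
                        self._tokens -= 1
                        return
                time.sleep(1.0 / self._rate)


    bucket = TokenBucket(rate=1024)


    def _generate(prompt: str, max_out_len: int) -> str:
        bucket.get_token()  # pace requests via the bucket before each call
        return f'<reply to {prompt!r}, up to {max_out_len} tokens>'  # fake HTTP call


    inputs = ['hello', 'world']
    # The fix: cap the fan-out explicitly instead of taking the default.
    with ThreadPoolExecutor(max_workers=128) as executor:
        results = list(executor.map(_generate, inputs, [1024] * len(inputs)))
    print(results)

The two knobs are complementary: the worker cap bounds concurrent in-flight requests (memory, sockets on the lightllm server), while the bucket bounds the request rate; the commit exposes both independently.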