"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bf9a641f1a51368af5f3ae99cc460107d4fa2103"
Unverified commit c54a5d3b, authored by RunningLeon, committed by GitHub

Support get_ppl for TurbomindModel (#878)

* Update get_ppl for TurboMindModel
* Update the api_server model wrapper
* Rename model configs and set thread_safe for the PyTorch engine when available

parent caf1cf8a
@@ -6,9 +6,9 @@ with read_base():
     from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
     from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
     from .datasets.SuperGLUE_WiC.SuperGLUE_WiC_gen_d06864 import WiC_datasets
-    from .datasets.SuperGLUE_WSC.SuperGLUE_WSC_gen_7902a7 import WSC_datasets
     from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from .datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
+    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
     from .datasets.race.race_gen_69ee4f import race_datasets
     from .datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets
     # and output the results in a chosen format
@@ -24,16 +24,29 @@ meta_template = dict(
     ],
     eos_token_id=103028)
 
-models = [
-    dict(
-        type=TurboMindAPIModel,
-        abbr='internlm-chat-20b-turbomind',
-        path="internlm-chat-20b",
-        api_addr='http://0.0.0.0:23333',
-        max_out_len=100,
-        max_seq_len=2048,
-        batch_size=8,
-        meta_template=meta_template,
-        run_cfg=dict(num_gpus=1, num_procs=1),
-    )
-]
+internlm_chat_20b = dict(
+    type=TurboMindAPIModel,
+    abbr='internlm-chat-20b-turbomind',
+    api_addr='http://0.0.0.0:23333',
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    meta_template=meta_template,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+    end_str='<eoa>',
+)
+
+internlm_chat_7b = dict(
+    type=TurboMindAPIModel,
+    abbr='internlm-chat-7b-turbomind',
+    api_addr='http://0.0.0.0:23333',
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=16,
+    meta_template=meta_template,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+    end_str='<eoa>',
+)
+
+models = [internlm_chat_20b]
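With each model now bound to a named dict instead of an anonymous entry in models, switching the evaluated checkpoint is a one-line edit at the bottom of the config. A minimal sketch using the names defined above:

# Evaluate the 20B model (the default above), the 7B variant, or both:
models = [internlm_chat_20b]
# models = [internlm_chat_7b]
# models = [internlm_chat_20b, internlm_chat_7b]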
@@ -14,15 +14,25 @@ with read_base():
 
 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
 
-models = [
-    dict(
-        type=TurboMindAPIModel,
-        abbr='internlm-chat-20b-turbomind',
-        path="internlm-chat-20b",
-        api_addr='http://0.0.0.0:23333',
-        max_out_len=100,
-        max_seq_len=2048,
-        batch_size=8,
-        run_cfg=dict(num_gpus=1, num_procs=1),
-    )
-]
+internlm_chat_20b = dict(
+    type=TurboMindAPIModel,
+    abbr='internlm-chat-20b-turbomind',
+    api_addr='http://0.0.0.0:23333',
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=8,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+)
+
+internlm_chat_7b = dict(
+    type=TurboMindAPIModel,
+    abbr='internlm-chat-7b-turbomind',
+    api_addr='http://0.0.0.0:23333',
+    max_out_len=100,
+    max_seq_len=2048,
+    batch_size=16,
+    run_cfg=dict(num_gpus=1, num_procs=1),
+)
+
+models = [internlm_chat_20b]
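The datasets line relies on a common OpenCompass idiom: every *_datasets list pulled in under read_base() is concatenated into one flat list, so registering a dataset only takes an import. A self-contained sketch with placeholder entries standing in for the real imports:

# Placeholders standing in for the read_base() dataset imports.
mmlu_datasets = [{'abbr': 'mmlu'}]
ceval_datasets = [{'abbr': 'ceval'}]

# Concatenate every list whose name ends with '_datasets'.
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
print([d['abbr'] for d in datasets])  # ['mmlu', 'ceval']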
@@ -54,6 +54,10 @@ class LmdeployPytorchModel(BaseModel):
         if engine_config is not None:
             from lmdeploy.messages import PytorchEngineConfig
             engine_config = PytorchEngineConfig(**engine_config)
+            # set thread_safe
+            if hasattr(engine_config, 'thread_safe'):
+                engine_config.thread_safe = True
         if gen_config is not None:
             from lmdeploy.messages import EngineGenerationConfig
             gen_config = EngineGenerationConfig(**gen_config)
...
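The hasattr guard keeps the wrapper working against older lmdeploy releases whose PytorchEngineConfig has no thread_safe field yet. A minimal sketch of the pattern, using a stand-in dataclass rather than the real lmdeploy class:

from dataclasses import dataclass

# Stand-in for lmdeploy's PytorchEngineConfig; the real class has many more
# fields, and only newer releases expose `thread_safe`.
@dataclass
class PytorchEngineConfig:
    tp: int = 1
    thread_safe: bool = False

engine_config = PytorchEngineConfig(tp=1)
# Flip the flag only when the installed version actually has it, so the same
# wrapper code runs against both old and new lmdeploy versions.
if hasattr(engine_config, 'thread_safe'):
    engine_config.thread_safe = True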
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Union
 
+import numpy as np
+
 from opencompass.models.base import BaseModel
 from opencompass.utils.logging import get_logger
 from opencompass.utils.prompt import PromptList
@@ -161,3 +163,29 @@ class TurboMindModel(BaseModel):
         if end_str:
             response = response.split(end_str)[0]
         return response
+
+    def get_ppl(self,
+                inputs: List[str],
+                mask_length: Optional[List[int]] = None) -> List[float]:
+        """Get perplexity scores given a list of inputs.
+
+        Args:
+            inputs (List[str]): A list of strings.
+            mask_length (Optional[List[int]]): A list of mask lengths. If
+                provided, the perplexity scores will be calculated with the
+                first mask_length[i] tokens masked out. It's okay to skip
+                its implementation if the advanced features of PPLInferencer
+                are not needed.
+
+        Returns:
+            np.ndarray: The perplexity scores in shape of (N,)
+        """
+        assert isinstance(
+            inputs, List), f'List(str) is expected, but got {type(inputs)}'
+        results = []
+        for text in inputs:
+            input_ids = self.tokenizer.encode(text)
+            res = self.generators[0].get_ppl(input_ids)
+            results.append(res)
+        results = np.concatenate(results)
+        return results
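The heavy lifting happens inside lmdeploy's generator.get_ppl; what the method reports is the standard quantity, perplexity as the exponential of the mean negative log-likelihood over the tokens. A self-contained sketch with made-up per-token log-probabilities (whether the backend returns the loss or the exponentiated value is an lmdeploy implementation detail):

import numpy as np

# Hypothetical per-token log-probabilities for one input sequence.
token_logprobs = np.array([-0.2, -0.5, -0.1, -1.3])
nll = -token_logprobs.mean()  # mean negative log-likelihood
ppl = float(np.exp(nll))      # perplexity, here ~1.69
print(ppl)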
@@ -20,30 +20,31 @@ def valid_str(string, coding='utf-8'):
 class TurboMindAPIModel(BaseModel):
-    """Model wrapper for TurboMind Triton Inference Server gRPC API.
+    """Model wrapper for lmdeploy api server.
 
     Args:
-        path (str): The name of OpenAI's model.
-        tis_addr (str): The address (ip:port format) of turbomind's
-            triton inference server
+        api_addr (str): The address (ip:port format) of lmdeploy's
+            api server.
         max_seq_len (int): The maximum allowed sequence length of a model.
             Note that the length of prompt + generated tokens shall not exceed
             this value. Defaults to 2048.
         meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
+        end_str (str, optional): Whether to trim generated strings with
+            end_str if the model has special ending strings that are not
+            handled well. Defaults to None.
     """
 
     is_api: bool = True
 
-    def __init__(
-        self,
-        path: str,
-        api_addr: str = 'http://0.0.0.0:23333',
-        max_seq_len: int = 2048,
-        meta_template: Optional[Dict] = None,
-    ):
-        super().__init__(path=path,
+    def __init__(self,
+                 api_addr: str = 'http://0.0.0.0:23333',
+                 max_seq_len: int = 2048,
+                 meta_template: Optional[Dict] = None,
+                 end_str: Optional[str] = None,
+                 **kwargs):
+        super().__init__(path='',
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.serve.openai.api_client import APIClient
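With path dropped from the signature, everything is addressed through the running api server. A hypothetical instantiation mirroring the config entries above (the import path, and a server already listening on the address, are assumptions):

from opencompass.models import TurboMindAPIModel

# Mirrors the internlm_chat_20b entry; meta_template is omitted here but can
# be passed exactly as in the configs above.
model = TurboMindAPIModel(api_addr='http://0.0.0.0:23333',
                          max_seq_len=2048,
                          end_str='<eoa>')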
@@ -55,6 +56,7 @@ class TurboMindAPIModel(BaseModel):
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
         self.api_addr = api_addr
+        self.end_str = end_str
 
     def generate(
         self,
@@ -73,7 +75,10 @@ class TurboMindAPIModel(BaseModel):
                 between 0 and 2. Higher values like 0.8 will make the output
                 more random, while lower values like 0.2 will make it more
                 focused and deterministic. Defaults to 0.7.
+            end_str (str, optional): Whether to trim generated strings
+                with end_str if the model has special ending strings
+                that are not handled well.
+                Defaults to None.
 
         Returns:
             List[str]: A list of generated strings.
         """
@@ -82,7 +87,8 @@ class TurboMindAPIModel(BaseModel):
         results = list(
             executor.map(self._generate, inputs,
                          [max_out_len] * len(inputs),
-                         [temperature] * len(inputs)))
+                         [temperature] * len(inputs),
+                         [self.end_str] * len(inputs)))
         return results
 
     def get_token_len(self, prompt: str) -> int:
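executor.map takes one iterable per positional parameter of the worker, which is why the per-call constants are broadcast as repeated lists and why self.end_str is threaded through as a fourth iterable. A self-contained sketch of the fan-out with a dummy worker in place of _generate:

from concurrent.futures import ThreadPoolExecutor

def fake_generate(prompt, max_out_len, temperature, end_str):
    # Dummy worker standing in for TurboMindAPIModel._generate.
    return f'{prompt}|{max_out_len}|{temperature}|{end_str}'

inputs = ['q1', 'q2']
with ThreadPoolExecutor() as executor:
    results = list(
        executor.map(fake_generate, inputs,
                     [100] * len(inputs),        # max_out_len, broadcast
                     [0.7] * len(inputs),        # temperature, broadcast
                     ['<eoa>'] * len(inputs)))   # end_str, broadcast
print(results)  # ['q1|100|0.7|<eoa>', 'q2|100|0.7|<eoa>']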
@@ -97,7 +103,7 @@ class TurboMindAPIModel(BaseModel):
         return self.token_bucket.get_token()
 
     def _generate(self, prompt: str or PromptList, max_out_len: int,
-                  temperature: float) -> str:
+                  temperature: float, end_str: str) -> str:
         """Generate results given a list of inputs.
 
         Args:
@@ -127,4 +133,6 @@ class TurboMindAPIModel(BaseModel):
                 top_k=1):
             response += output['choices'][0]['text']
         response = valid_str(response)
+        if end_str:
+            response = response.split(end_str)[0]
         return response
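The trimming added at the end of _generate drops everything from the first occurrence of the end marker onward. A self-contained sketch:

# A raw completion that runs past the model's end-of-answer marker.
response = 'The answer is 42.<eoa>trailing chatter the server kept streaming'
end_str = '<eoa>'
if end_str:
    response = response.split(end_str)[0]
print(response)  # 'The answer is 42.'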