Unverified Commit 107e022c authored by Yang Yong, committed by GitHub

Support prompt template for LightllmApi. Update LightllmApi token bucket. (#945)

parent c54a5d3b
...@@ -5,18 +5,35 @@ from opencompass.runners import LocalRunner
 from opencompass.tasks import OpenICLInferTask

 with read_base():
+    from .summarizers.leaderboard import summarizer
     from .datasets.humaneval.humaneval_gen import humaneval_datasets

 datasets = [*humaneval_datasets]

+'''
+# Prompt template for InternLM2-Chat
+# https://github.com/InternLM/InternLM/blob/main/chat/chat_format.md
+_meta_template = dict(
+    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
+    round=[
+        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
+        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
+    ]
+)
+'''
+_meta_template = None
+
 models = [
     dict(
         abbr='LightllmAPI',
         type=LightllmAPI,
-        url='http://localhost:8080/generate',
-        input_format='<input_text_to_replace>',
-        max_seq_len=2048,
+        url='http://localhost:1030/generate',
+        meta_template=_meta_template,
         batch_size=32,
+        rate_per_worker=32,
+        retry=4,
         generation_kwargs=dict(
             do_sample=False,
             ignore_eos=False,
......
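The commented-out block above documents the intended use of the new `meta_template` hook, which replaces the old `input_format` placeholder string. As a rough sketch (not a file in this commit), a config variant that actually enables the InternLM2-Chat template for a Lightllm-served model could look like the following; the imports are assumed to follow the usual OpenCompass config layout.

```python
# Hypothetical config sketch with the InternLM2-Chat template enabled,
# assembled from the diff above; not part of the commit itself.
from mmengine.config import read_base

from opencompass.models import LightllmAPI

with read_base():
    from .datasets.humaneval.humaneval_gen import humaneval_datasets

datasets = [*humaneval_datasets]

# Chat-format markers from InternLM2-Chat's chat_format.md: each HUMAN/BOT turn
# is wrapped in <|im_start|>/<|im_end|> tags before the prompt is sent to Lightllm.
_meta_template = dict(
    begin='<|im_start|>system\nYou are InternLM2-Chat, a harmless AI assistant<|im_end|>\n',
    round=[
        dict(role='HUMAN', begin='<|im_start|>user\n', end='<|im_end|>\n'),
        dict(role='BOT', begin='<|im_start|>assistant\n', end='<|im_end|>\n', generate=True),
    ],
)

models = [
    dict(
        abbr='LightllmAPI',
        type=LightllmAPI,
        url='http://localhost:1030/generate',
        meta_template=_meta_template,  # replaces the removed input_format string
        batch_size=32,
        rate_per_worker=32,            # per-worker request rate fed to the new TokenBucket
        retry=4,
        generation_kwargs=dict(do_sample=False, ignore_eos=False),
    ),
]
```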
...@@ -21,14 +21,21 @@ We use the evaluation of Humaneval with the llama2-7B model as an example.
 ```shell
 python -m lightllm.server.api_server --model_dir /path/llama2-7B \
                                      --host 0.0.0.0              \
-                                     --port 8080                 \
+                                     --port 1030                 \
+                                     --nccl_port 2066            \
+                                     --max_req_input_len 4096    \
+                                     --max_req_total_len 6144    \
                                      --tp 1                      \
+                                     --trust_remote_code         \
                                      --max_total_token_num 120000
 ```

 **Note:** tp can be configured to enable TensorParallel inference on several GPUs, which is suitable for the inference of very large models.
 **Note:** The max_total_token_num in the above command will affect the throughput performance during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of memory, it is often better to set it as high as possible.
+**Note:** If you want to start multiple Lightllm services on the same machine, you need to reconfigure the port and nccl_port above.

 You can use the following Python script to quickly test whether the current service has been successfully started.

 ```python
......
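The body of that test script is collapsed in this diff. A minimal smoke test along the same lines, assuming the default `/generate` endpoint and the `inputs`/`parameters` JSON payload that `LightllmAPI` posts further below, might look like:

```python
# Hypothetical smoke test for a running Lightllm server; the payload shape
# (inputs + parameters) mirrors what LightllmAPI sends in the diff below.
import json

import requests

url = 'http://localhost:1030/generate'  # same port as the server command above
data = {
    'inputs': 'What is AI?',
    'parameters': {'do_sample': False, 'ignore_eos': False, 'max_new_tokens': 64},
}
response = requests.post(url,
                         headers={'content-type': 'application/json'},
                         data=json.dumps(data))
print(response.status_code)
print(response.text)
```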
...@@ -21,14 +21,21 @@
 ```shell
 python -m lightllm.server.api_server --model_dir /path/llama2-7B \
                                      --host 0.0.0.0              \
-                                     --port 8080                 \
+                                     --port 1030                 \
+                                     --nccl_port 2066            \
+                                     --max_req_input_len 4096    \
+                                     --max_req_total_len 6144    \
                                      --tp 1                      \
+                                     --trust_remote_code         \
                                      --max_total_token_num 120000
 ```

 **Note:** tp sets the number of GPUs used for TensorParallel inference, which is suitable for larger models.
 **Note:** The max_total_token_num in the command above affects throughput during testing; it can be set according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of GPU memory, larger is usually better.
+**Note:** To start multiple Lightllm services on the same machine, you need to change the port and nccl_port above.

 You can use the following Python script to quickly check whether the current service has started successfully.

 ```python
......
...@@ -8,11 +8,12 @@ import requests
 from opencompass.registry import MODELS
 from opencompass.utils.logging import get_logger

-from .base_api import BaseAPIModel
+from .base import BaseModel
+from .base_api import TokenBucket


 @MODELS.register_module()
-class LightllmAPI(BaseAPIModel):
+class LightllmAPI(BaseModel):

     is_api: bool = True

...@@ -20,23 +21,21 @@ class LightllmAPI(BaseAPIModel):
         self,
         path: str = 'LightllmAPI',
         url: str = 'http://localhost:8080/generate',
-        input_format: str = '<input_text_to_replace>',
-        max_seq_len: int = 2048,
         meta_template: Optional[Dict] = None,
+        rate_per_worker: int = 2,
         retry: int = 2,
         generation_kwargs: Optional[Dict] = dict(),
     ):
         super().__init__(path=path,
-                         max_seq_len=max_seq_len,
                          meta_template=meta_template,
-                         retry=retry,
                          generation_kwargs=generation_kwargs)
         self.logger = get_logger()
         self.url = url
-        self.input_format = input_format
+        self.retry = retry
         self.generation_kwargs = generation_kwargs
         self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)
+        self.token_bucket = TokenBucket(rate_per_worker, False)

     def generate(self, inputs: List[str], max_out_len: int,
                  **kwargs) -> List[str]:
...@@ -64,8 +63,6 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                 data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
...@@ -118,8 +115,6 @@ class LightllmAPI(BaseAPIModel):
             self.wait()
             header = {'content-type': 'application/json'}
             try:
-                input = self.input_format.replace('<input_text_to_replace>',
-                                                  input)
                 data = dict(inputs=input, parameters=self.generation_kwargs)
                 raw_response = requests.post(self.url,
                                              headers=header,
...@@ -156,3 +151,10 @@ class LightllmAPI(BaseAPIModel):
         raise RuntimeError('Calling LightllmAPI failed after retrying for '
                            f'{max_num_retries} times. Check the logs for '
                            'details.')
+
+    def wait(self):
+        """Wait till the next query can be sent.
+
+        Applicable in both single-thread and multi-thread environments.
+        """
+        return self.token_bucket.get_token()
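With `input_format` gone, prompt construction is delegated to the meta template, and throttling moves from `BaseAPIModel` into an explicit `TokenBucket`: `rate_per_worker` caps how fast each worker may call the server, and `wait()` simply blocks on `get_token()` before every request. For readers unfamiliar with the pattern, here is a simplified, self-contained token-bucket sketch; it only mirrors the idea and is not the `TokenBucket` class imported from `opencompass/models/base_api.py`.

```python
# Simplified stand-in for the rate limiting behind wait(): at most `rate`
# acquisitions per second, shared across threads. Illustration only.
import threading
import time


class SimpleTokenBucket:
    """Allow at most `rate` get_token() calls per second across threads."""

    def __init__(self, rate: float):
        self._interval = 1.0 / rate
        self._lock = threading.Lock()
        self._next_slot = 0.0

    def get_token(self):
        with self._lock:
            now = time.monotonic()
            slot = max(now, self._next_slot)   # earliest time this call may proceed
            self._next_slot = slot + self._interval
        time.sleep(max(0.0, slot - now))       # sleep outside the lock


if __name__ == '__main__':
    bucket = SimpleTokenBucket(rate=32)        # cf. rate_per_worker=32 in the config
    start = time.monotonic()
    for i in range(5):
        bucket.get_token()                     # successive calls are spaced ~1/32 s apart
        print(f'request {i} dispatched at t={time.monotonic() - start:.3f}s')
```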