Unverified commit b39f5015, authored by Fengzhe Zhou and committed by GitHub

[Sync] update taco (#1030)

parent 16f29b25
@@ -201,3 +201,77 @@ class CIReAct(ReAct):
        self._session_history.append(
            dict(role='assistant', content=agent_return.response))
        return agent_return

class CIReActMergeRole(CIReAct):
    """Merge chat roles so that USER and BOT turns alternate: a SYSTEM
    message in the first turn is kept as SYSTEM, any later SYSTEM message
    is sent as USER, and consecutive USER turns are merged into one."""
def chat(self, message: str) -> AgentReturn:
for hist in self._session_history:
if hist['role'] == 'system':
hist['role'] = self.system_role
self._inner_history = []
# append the user message for session history
self._session_history.append(dict(role='user', content=message))
agent_return = AgentReturn()
force_stop = False
        default_response = '对不起,我无法回答你的问题'  # 'Sorry, I cannot answer your question.'
for turn in range(self.max_turn):
prompt = self._protocol.format(
chat_history=self.session_history,
inner_step=self._inner_history,
action_executor=self._action_executor,
force_stop=force_stop)
prompt = self.merge_role(prompt)
response = self._llm.generate_from_template(prompt, 512)
self._inner_history.append(dict(role='assistant',
content=response))
thought, action, action_input = self._protocol.parse(
response, self._action_executor)
action_return: ActionReturn = self._action_executor(
action, action_input)
action_return.thought = thought
agent_return.actions.append(action_return)
if action_return.state == ActionStatusCode.SUCCESS:
# if success, stash model response and system response
self._session_history.append(
dict(role='assistant', content=response))
self._session_history.append(
dict(
role=self.system_role,
content=self._protocol.format_response(action_return)))
agent_return.response = action_return.result['text']
return agent_return
elif action_return.type == self._action_executor.invalid_action.name: # noqa
action_return.errmsg = 'The action is invalid, please check the action name.' # noqa
self._inner_history.append(
dict(role=self.system_role,
content=self._protocol.format_response(action_return)))
if turn == self.max_turn - 1:
force_stop = True
agent_return.response = default_response
self._session_history.append(
dict(role='assistant', content=agent_return.response))
return agent_return
def merge_role(self, inputs):
messages = []
msg_buffer, last_role = [], None
for index, item in enumerate(inputs):
if index == 0 and item['role'] == 'system':
role = 'system'
elif item['role'] == 'assistant':
role = 'assistant'
else:
role = 'user'
if role != last_role and last_role is not None:
messages.append({
'content': '\n'.join(msg_buffer),
'role': last_role
})
msg_buffer = []
msg_buffer.append(item['content'])
last_role = role
messages.append({'content': '\n'.join(msg_buffer), 'role': last_role})
return messages
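
For illustration, here is a minimal standalone sketch of the merging rule implemented by merge_role (the driver below is hypothetical and not part of this commit): a first-turn system message is preserved, later system messages are folded into the user role, and consecutive same-role messages are joined with newlines so user and assistant turns strictly alternate.

def merge_role_demo(inputs):
    # same logic as CIReActMergeRole.merge_role, copied for a self-contained demo
    messages = []
    msg_buffer, last_role = [], None
    for index, item in enumerate(inputs):
        if index == 0 and item['role'] == 'system':
            role = 'system'
        elif item['role'] == 'assistant':
            role = 'assistant'
        else:
            role = 'user'
        if role != last_role and last_role is not None:
            messages.append({'content': '\n'.join(msg_buffer), 'role': last_role})
            msg_buffer = []
        msg_buffer.append(item['content'])
        last_role = role
    messages.append({'content': '\n'.join(msg_buffer), 'role': last_role})
    return messages

history = [
    {'role': 'system', 'content': 'You are a coding assistant.'},
    {'role': 'user', 'content': 'Run the cell.'},
    {'role': 'system', 'content': 'Execution result: 42'},  # tool feedback
    {'role': 'assistant', 'content': 'The answer is 42.'},
]
print(merge_role_demo(history))
# [{'content': 'You are a coding assistant.', 'role': 'system'},
#  {'content': 'Run the cell.\nExecution result: 42', 'role': 'user'},
#  {'content': 'The answer is 42.', 'role': 'assistant'}]
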
from .accessory import LLaMA2AccessoryModel # noqa: F401
from .ai360_api import AI360GPT # noqa: F401
from .alaya import AlayaLM # noqa: F401
-from .baichuan_api import BaiChuan # noqa: F401
+from .baichuan_api import BaiChuan, BaiChuan3 # noqa: F401
from .baidu_api import ERNIEBot # noqa: F401
from .base import BaseModel, LMTemplateParser # noqa
from .base_api import APITemplateParser, BaseAPIModel # noqa
@@ -12,12 +12,14 @@ from .glm import GLM130B # noqa: F401, F403
from .huggingface import HuggingFace # noqa: F401, F403
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
from .huggingface import HuggingFaceChatGLM3 # noqa: F401, F403
+from .hunyuan_api import Hunyuan # noqa: F401
from .intern_model import InternLM # noqa: F401, F403
from .krgpt_api import KrGPT # noqa: F401
from .lightllm_api import LightllmAPI # noqa: F401
from .llama2 import Llama2, Llama2Chat # noqa: F401, F403
from .lmdeploy_pytorch import LmdeployPytorchModel # noqa: F401
from .minimax_api import MiniMax # noqa: F401
+from .mistral_api import Mistral # noqa: F401
from .mixtral import Mixtral # noqa: F401
from .modelscope import ModelScope, ModelScopeCausalLM # noqa: F401, F403
from .moonshot_api import MoonShot # noqa: F401
@@ -28,7 +30,9 @@ from .qwen_api import Qwen # noqa: F401
from .sensetime_api import SenseTime # noqa: F401
from .turbomind import TurboMindModel # noqa: F401
from .turbomind_tis import TurboMindTisModel # noqa: F401
+from .unigpt_api import UniGPT # noqa: F401
from .vllm import VLLM # noqa: F401
from .xunfei_api import XunFei # noqa: F401
+from .yayi_api import Yayi # noqa: F401
from .zhipuai_api import ZhiPuAI # noqa: F401
from .zhipuai_v2_api import ZhiPuV2AI # noqa: F401
@@ -60,13 +60,13 @@ class AI360GPT(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -83,13 +83,13 @@ class AI360GPT(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -59,13 +59,13 @@ class BaiChuan(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -82,13 +82,13 @@ class BaiChuan(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -157,3 +157,127 @@ class BaiChuan(BaseAPIModel):
            max_num_retries += 1

        raise RuntimeError(response)

class BaiChuan3(BaseAPIModel):
def __init__(
self,
path: str,
api_key: str,
url: str,
query_per_second: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
): # noqa E125
super().__init__(path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry)
self.api_key = api_key
self.url = url
self.model = path
def generate(
self,
inputs: List[PromptType],
max_out_len: int = 512,
) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs)))
self.flush()
return results
def _generate(
self,
input: PromptType,
max_out_len: int = 512,
) -> str:
"""Generate results given an input.
Args:
inputs (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
str: The generated string.
"""
assert isinstance(input, (str, PromptList))
if isinstance(input, str):
history = []
prompt = input
else:
messages = []
msg_buffer, last_role = [], None
for item in input:
role = 'BOT' if item['role'] == 'BOT' else 'USER'
if role != last_role and last_role is not None:
messages.append({
'data': '\n'.join(msg_buffer),
'from': 0 if last_role == 'USER' else 1
})
msg_buffer = []
msg_buffer.append(item['prompt'])
last_role = role
messages.append({
'data': '\n'.join(msg_buffer),
'from': 0 if last_role == 'USER' else 1
})
history = messages[:-1]
prompt = messages[-1]['data']
data = {
'access_token_key': self.api_key,
'app_info': {
'id': 123
},
'prompt': {
'data': prompt
},
'history': history,
}
for _ in range(self.retry):
try:
response = requests.post(self.url, json=data)
except Exception as e:
print(e)
continue
if response is None or response.status_code != 200:
code = response.status_code if response else -1
print(f'[chat_api]-[failed] request err, status_code: {code}')
continue
try:
response = response.json()
except Exception as e:
print(e)
continue
print(response)
status = response.get('answer', {}).get('status', 0)
session_status = response.get('session_info', {}).get('status', 0)
if status < 0 or session_status < 0:
print('[chat_api]-[warn] prompt or answer is unsafe')
return 'Rejection: unsafe prompt or answer'
return response.get('answer', {}).get('data', '')
raise RuntimeError(response['msg'])
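
As a rough illustration of the request format BaiChuan3._generate assembles (all values below are placeholders except app_info.id, which is hard-coded above): earlier turns become history entries with 'from': 0 for USER and 1 for BOT, and the final user turn is sent as the prompt.

# Hypothetical payload for a three-turn conversation (illustrative values).
payload = {
    'access_token_key': '<api-key>',           # placeholder
    'app_info': {'id': 123},                   # hard-coded in _generate
    'prompt': {'data': 'And the day after?'},  # final USER turn
    'history': [
        {'data': 'What is the weather today?', 'from': 0},  # USER
        {'data': 'Sunny, around 20 degrees.', 'from': 1},   # BOT
    ],
}
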
@@ -88,13 +88,13 @@ class ERNIEBot(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -111,13 +111,13 @@ class ERNIEBot(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -129,7 +129,7 @@ class BaseModel:
            applicable.

        Args:
-            prompt_template (List[str or PromptList]): A prompt
+            prompt_template (List[PromptType]): A prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
@@ -266,7 +266,7 @@ class LMTemplateParser:
            applicable.

        Args:
-            prompt_template (List[str or PromptList]): A prompt
+            prompt_template (List[PromptType]): A prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
...
@@ -60,7 +60,7 @@ class BaseAPIModel(BaseModel):
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -111,7 +111,7 @@ class BaseAPIModel(BaseModel):
        """Get perplexity scores given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings.
+            inputs (List[PromptType]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
@@ -200,12 +200,12 @@ class APITemplateParser:
            {'role': 'user', 'prompt': '...'}).

        Args:
-            prompt_template (List[str or PromptList]): An intermidate prompt
+            prompt_template (List[PromptType]): An intermidate prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
-            List[str or PromptList]: The finalized prompt or a conversation.
+            List[PromptType]: The finalized prompt or a conversation.
        """
        assert isinstance(prompt_template, (str, list, PromptList, tuple))
...
@@ -64,13 +64,13 @@ class ByteDance(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -87,13 +87,13 @@ class ByteDance(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -52,13 +52,13 @@ class Claude(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -74,13 +74,13 @@ class Claude(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -58,13 +58,13 @@ class Gemini(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -81,13 +81,13 @@ class Gemini(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -234,13 +234,13 @@ class GeminiAllesAPIN(Gemini):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -723,7 +723,7 @@ class HuggingFaceChatGLM3(HuggingFace):
        self.num_extra_tokens = num_extra_tokens

    def generate(self,
-                 inputs: List[str or PromptList],
+                 inputs: List[PromptType],
                 max_out_len: int = 512,
                 skip_overlength=False,
                 **kwargs) -> str:
...
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
class Hunyuan(BaseAPIModel):
def __init__(
self,
path: str,
secret_id: str,
secret_key: str,
endpoint: str,
query_per_second: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
): # noqa E125
super().__init__(
path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry,
)
self.secret_id = secret_id
self.secret_key = secret_key
self.endpoint = endpoint
from tencentcloud.common import credential
from tencentcloud.common.common_client import CommonClient
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
cred = credential.Credential(self.secret_id, self.secret_key)
httpProfile = HttpProfile()
httpProfile.endpoint = self.endpoint
clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile
self.client = CommonClient('hunyuan',
'2023-09-01',
cred,
'ap-beijing',
profile=clientProfile)
def generate(self,
inputs: List[PromptType],
max_out_len: int = 512) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs)))
self.flush()
return results
def _generate(self, input: PromptType, max_out_len: int = 512) -> str:
"""Generate results given an input.
Args:
inputs (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
str: The generated string.
"""
assert isinstance(input, (str, PromptList))
        if isinstance(input, str):
            # use the capitalized keys ('Role'/'Content') the Hunyuan API
            # expects, matching the PromptList branch below
            messages = [{'Role': 'user', 'Content': input}]
else:
messages = []
for item in input:
msg = {'Content': item['prompt']}
if item['role'] == 'HUMAN':
msg['Role'] = 'user'
elif item['role'] == 'BOT':
msg['Role'] = 'assistant'
messages.append(msg)
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import \
            TencentCloudSDKException

        data = {'Messages': messages}
        for _ in range(self.retry):
            try:
                resp = self.client.call_sse('ChatPro', data)
                contents = []
                for event in resp:
                    part = json.loads(event['data'])
                    contents.append(part['Choices'][0]['Delta']['Content'])
                answer = ''.join(contents)
            except TencentCloudSDKException as err:
                # retry on SDK errors instead of falling through with an
                # undefined `answer`
                print(err)
                continue
            print(answer)
            return answer

        raise RuntimeError(f'Failed to respond in {self.retry} retries')
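
A small self-contained sketch of the stream concatenation above, with event payloads shaped like the fields the loop reads (Choices[0].Delta.Content); real events come from CommonClient.call_sse, so the samples here are assumptions.

import json

events = [  # illustrative SSE events, mirroring what the loop above parses
    {'data': json.dumps({'Choices': [{'Delta': {'Content': 'Hel'}}]})},
    {'data': json.dumps({'Choices': [{'Delta': {'Content': 'lo!'}}]})},
]
answer = ''.join(
    json.loads(event['data'])['Choices'][0]['Delta']['Content']
    for event in events)
print(answer)  # -> Hello!
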
@@ -199,7 +199,7 @@ class Llama2Chat(BaseModel):
        self.tokenizer = Tokenizer(tokenizer_path)

    def generate(self,
-                 inputs: List[str or PromptList],
+                 inputs: List[PromptType],
                 max_out_len: int = 512,
                 temperature: float = 0.6) -> str:
        """Generate response from input prompt.
...
@@ -124,13 +124,13 @@ class LmdeployPytorchModel(BaseModel):
    def _generate(self,
                  generator,
                  session_id,
-                  prompt: str or PromptList,
+                  prompt: PromptType,
                  gen_config=None,
                  end_str: Optional[str] = None) -> str:
        """Generate results given a list of inputs.

        Args:
-            prompt (str or PromptList): A string or PromptDict.
+            prompt (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            gen_config (EngineGenerationConfig, optional): Generation
...
@@ -60,13 +60,13 @@ class MiniMax(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -83,13 +83,13 @@ class MiniMax(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
import requests
from opencompass.utils.prompt import PromptList
from .base_api import BaseAPIModel
PromptType = Union[PromptList, str]
class Mistral(BaseAPIModel):
def __init__(
self,
path: str,
api_key: str,
url: str,
query_per_second: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
): # noqa E125
super().__init__(
path=path,
max_seq_len=max_seq_len,
query_per_second=query_per_second,
meta_template=meta_template,
retry=retry,
)
self.api_key = api_key
self.url = url
self.model = path
def generate(self,
inputs: List[PromptType],
max_out_len: int = 512) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[PromptType]): A list of strings or PromptDicts.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
with ThreadPoolExecutor() as executor:
results = list(
executor.map(self._generate, inputs,
[max_out_len] * len(inputs)))
self.flush()
return results
def _generate(self, input: PromptType, max_out_len: int = 512) -> str:
"""Generate results given an input.
Args:
inputs (PromptType): A string or PromptDict.
The PromptDict should be organized in OpenCompass'
API format.
max_out_len (int): The maximum length of the output.
Returns:
str: The generated string.
"""
assert isinstance(input, (str, PromptList))
if isinstance(input, str):
messages = [{'role': 'user', 'content': input}]
else:
messages = []
for item in input:
msg = {'content': item['prompt']}
if item['role'] == 'HUMAN':
msg['role'] = 'user'
elif item['role'] == 'BOT':
msg['role'] = 'assistant'
elif item['role'] == 'SYSTEM':
msg['role'] = 'system'
messages.append(msg)
messages[-1]['role'] = 'user'
data = {
'model': self.path,
'messages': messages,
}
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': f'Bearer {self.api_key}',
}
from pprint import pprint
print('-' * 128)
pprint(data)
for _ in range(self.retry):
try:
response = requests.post(self.url, json=data, headers=headers)
except Exception as e:
print(e)
continue
try:
response = response.json()
except Exception as e:
print(e)
continue
print('=' * 128)
pprint(response)
try:
msg = response['choices'][0]['message']['content']
except Exception as e:
print(e)
continue
return msg
        raise RuntimeError(f'Failed to respond in {self.retry} retries')
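
To make the role mapping in Mistral._generate concrete, a hypothetical walk-through (sample prompts invented; forcing the last role to 'user' mirrors the messages[-1]['role'] = 'user' line above, presumably because the endpoint expects the final turn to come from the user):

prompt_list = [
    {'role': 'SYSTEM', 'prompt': 'Answer concisely.'},
    {'role': 'HUMAN', 'prompt': 'Name a prime number.'},
]
role_map = {'HUMAN': 'user', 'BOT': 'assistant', 'SYSTEM': 'system'}
messages = [{'role': role_map[item['role']], 'content': item['prompt']}
            for item in prompt_list]
messages[-1]['role'] = 'user'  # same normalization as in _generate
# messages == [{'role': 'system', 'content': 'Answer concisely.'},
#              {'role': 'user', 'content': 'Name a prime number.'}]
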
@@ -55,13 +55,13 @@ class MoonShot(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -78,13 +78,13 @@ class MoonShot(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -98,29 +98,27 @@ class MoonShot(BaseAPIModel):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
+            msg_buffer, last_role = [], None
            for item in input:
-                msg = {'content': item['prompt']}
-                if item['role'] == 'HUMAN':
-                    msg['role'] = 'user'
-                elif item['role'] == 'BOT':
-                    msg['role'] = 'assistant'
-
-                messages.append(msg)
-
-        system = {
-            'role': 'system',
-            'content': self.system_prompt
-            # '你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。'
-            # '你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一些涉及恐怖主义,种族歧视,'
-            # '黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。'
-        }
-
-        messages.insert(0, system)
+                item['role'] = 'assistant' if item['role'] == 'BOT' else 'user'
+                if item['role'] != last_role and last_role is not None:
+                    messages.append({
+                        'content': '\n'.join(msg_buffer),
+                        'role': last_role
+                    })
+                    msg_buffer = []
+                msg_buffer.append(item['prompt'])
+                last_role = item['role']
+            messages.append({
+                'content': '\n'.join(msg_buffer),
+                'role': last_role
+            })
+
+        if self.system_prompt:
+            system = {'role': 'system', 'content': self.system_prompt}
+            messages.insert(0, system)

-        data = {
-            'model': self.model,
-            'messages': messages,
-        }
+        data = {'model': self.model, 'messages': messages}

        max_num_retries = 0
        while max_num_retries < self.retry:
...
@@ -52,13 +52,13 @@ class Nanbeige(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -75,13 +75,13 @@ class Nanbeige(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -103,14 +103,14 @@ class OpenAI(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -132,12 +132,12 @@ class OpenAI(BaseAPIModel):
                         [temperature] * len(inputs)))
        return results

-    def _generate(self, input: str or PromptList, max_out_len: int,
+    def _generate(self, input: PromptType, max_out_len: int,
                  temperature: float) -> str:
        """Generate results given a list of inputs.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -207,6 +207,7 @@ class OpenAI(BaseAPIModel):
        header = {
            'Authorization': f'Bearer {key}',
            'content-type': 'application/json',
+            'api-key': key,
        }

        if self.orgs:
@@ -239,6 +240,7 @@ class OpenAI(BaseAPIModel):
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
                continue
+            self.logger.error(str(response))
            try:
                if self.logprobs:
                    return response['choices']
@@ -247,13 +249,16 @@ class OpenAI(BaseAPIModel):
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
-                        time.sleep(1)
+                        time.sleep(10)
                        self.logger.warn('Rate limit exceeded, retrying...')
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warn(f'insufficient_quota key: {key}')
                        continue
+                    elif response['error']['code'] == 'invalid_prompt':
+                        self.logger.warn('Invalid prompt:', str(input))
+                        return ''

                    self.logger.error('Find error message in response: ',
                                      str(response['error']))
@@ -363,12 +368,12 @@ class OpenAIAllesAPIN(OpenAI):
            'content-type': 'application/json',
        }

-    def _generate(self, input: str or PromptList, max_out_len: int,
+    def _generate(self, input: PromptType, max_out_len: int,
                  temperature: float) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...
@@ -67,13 +67,13 @@ class PanGu(BaseAPIModel):
    def generate(
        self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
@@ -117,13 +117,13 @@ class PanGu(BaseAPIModel):
    def _generate(
        self,
-        input: str or PromptList,
+        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
...