OpenDAS / opencompass · Commits

Commit b39f5015 (unverified)

[Sync] update taco (#1030)

Authored Apr 09, 2024 by Fengzhe Zhou; committed by GitHub on Apr 09, 2024.
Parent commit: 16f29b25
Changes: 87 files in the commit; this page shows 20 changed files, with 531 additions and 82 deletions.
Changed files (20 shown on this page):

- opencompass/lagent/agents/react.py (+74, -0)
- opencompass/models/__init__.py (+5, -1)
- opencompass/models/ai360_api.py (+4, -4)
- opencompass/models/baichuan_api.py (+128, -4)
- opencompass/models/baidu_api.py (+4, -4)
- opencompass/models/base.py (+2, -2)
- opencompass/models/base_api.py (+4, -4)
- opencompass/models/bytedance_api.py (+4, -4)
- opencompass/models/claude_api/claude_api.py (+4, -4)
- opencompass/models/gemini_api.py (+6, -6)
- opencompass/models/huggingface.py (+1, -1)
- opencompass/models/hunyuan_api.py (+121, -0)
- opencompass/models/llama2.py (+1, -1)
- opencompass/models/lmdeploy_pytorch.py (+2, -2)
- opencompass/models/minimax_api.py (+4, -4)
- opencompass/models/mistral_api.py (+123, -0)
- opencompass/models/moonshot_api.py (+24, -26)
- opencompass/models/nanbeige_api.py (+4, -4)
- opencompass/models/openai_api.py (+12, -7)
- opencompass/models/pangu_api.py (+4, -4)
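One change recurs across nearly every model file below: the annotation `List[str or PromptList]` is replaced by `List[PromptType]`. The old spelling was misleading, since `str or PromptList` is an ordinary or-expression that evaluates to just `str` when the annotation is computed. The alias, as defined at the top of the new hunyuan_api.py and mistral_api.py files in this commit, is:

```python
from typing import Union

from opencompass.utils.prompt import PromptList

# A prompt is either a raw string or a PromptList of role/prompt dicts.
PromptType = Union[PromptList, str]
```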
opencompass/lagent/agents/react.py

The hunk appends a new `CIReActMergeRole` subclass after `CIReAct`:

```diff
@@ -201,3 +201,77 @@ class CIReAct(ReAct):
         self._session_history.append(
             dict(role='assistant', content=agent_return.response))
         return agent_return
+
+
+class CIReActMergeRole(CIReAct):
+    """如有第一轮 SYSTEM, 则使用 SYSTEM。后续 SYSTEM 使用 USER 合并复数轮 USER
+    USER 与 BOT 交替出现."""
+
+    def chat(self, message: str) -> AgentReturn:
+        for hist in self._session_history:
+            if hist['role'] == 'system':
+                hist['role'] = self.system_role
+        self._inner_history = []
+        # append the user message for session history
+        self._session_history.append(dict(role='user', content=message))
+        agent_return = AgentReturn()
+        force_stop = False
+        default_response = '对不起,我无法回答你的问题'
+        for turn in range(self.max_turn):
+            prompt = self._protocol.format(
+                chat_history=self.session_history,
+                inner_step=self._inner_history,
+                action_executor=self._action_executor,
+                force_stop=force_stop)
+            prompt = self.merge_role(prompt)
+            response = self._llm.generate_from_template(prompt, 512)
+            self._inner_history.append(
+                dict(role='assistant', content=response))
+            thought, action, action_input = self._protocol.parse(
+                response, self._action_executor)
+            action_return: ActionReturn = self._action_executor(
+                action, action_input)
+            action_return.thought = thought
+            agent_return.actions.append(action_return)
+            if action_return.state == ActionStatusCode.SUCCESS:
+                # if success, stash model response and system response
+                self._session_history.append(
+                    dict(role='assistant', content=response))
+                self._session_history.append(
+                    dict(role=self.system_role,
+                         content=self._protocol.format_response(action_return)))
+                agent_return.response = action_return.result['text']
+                return agent_return
+            elif action_return.type == self._action_executor.invalid_action.name:  # noqa
+                action_return.errmsg = 'The action is invalid, please check the action name.'  # noqa
+                self._inner_history.append(
+                    dict(role=self.system_role,
+                         content=self._protocol.format_response(action_return)))
+            if turn == self.max_turn - 1:
+                force_stop = True
+        agent_return.response = default_response
+        self._session_history.append(
+            dict(role='assistant', content=agent_return.response))
+        return agent_return
+
+    def merge_role(self, inputs):
+        messages = []
+        msg_buffer, last_role = [], None
+        for index, item in enumerate(inputs):
+            if index == 0 and item['role'] == 'system':
+                role = 'system'
+            elif item['role'] == 'assistant':
+                role = 'assistant'
+            else:
+                role = 'user'
+            if role != last_role and last_role is not None:
+                messages.append({
+                    'content': '\n'.join(msg_buffer),
+                    'role': last_role
+                })
+                msg_buffer = []
+            msg_buffer.append(item['content'])
+            last_role = role
+        messages.append({
+            'content': '\n'.join(msg_buffer),
+            'role': last_role
+        })
+        return messages
```
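The class docstring is in Chinese; it says, roughly: if the first turn is SYSTEM, keep it as SYSTEM; treat any later SYSTEM turns as USER; and merge consecutive same-role turns so that USER and BOT strictly alternate. (The `default_response` literal means "Sorry, I cannot answer your question.") A minimal sketch of what `merge_role` does to a formatted prompt, assuming plain role/content dicts as input:

```python
# Hypothetical input, as might come out of the protocol formatter.
prompt = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Question 1'},
    {'role': 'system', 'content': 'Tool output'},   # not at index 0, so treated as user
    {'role': 'assistant', 'content': 'Answer 1'},
]
# merge_role keeps the leading system turn, folds the mid-conversation
# system turn into the user side, and joins consecutive same-role turns:
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Question 1\nTool output'},
#  {'role': 'assistant', 'content': 'Answer 1'}]
```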
opencompass/models/__init__.py

The registry gains imports for the new wrappers and extends the BaiChuan line:

```diff
 from .accessory import LLaMA2AccessoryModel  # noqa: F401
 from .ai360_api import AI360GPT  # noqa: F401
 from .alaya import AlayaLM  # noqa: F401
-from .baichuan_api import BaiChuan  # noqa: F401
+from .baichuan_api import BaiChuan, BaiChuan3  # noqa: F401
 from .baidu_api import ERNIEBot  # noqa: F401
 from .base import BaseModel, LMTemplateParser  # noqa
 from .base_api import APITemplateParser, BaseAPIModel  # noqa
@@ -12,12 +12,14 @@ from .glm import GLM130B  # noqa: F401, F403
 from .huggingface import HuggingFace  # noqa: F401, F403
 from .huggingface import HuggingFaceCausalLM  # noqa: F401, F403
 from .huggingface import HuggingFaceChatGLM3  # noqa: F401, F403
+from .hunyuan_api import Hunyuan  # noqa: F401
 from .intern_model import InternLM  # noqa: F401, F403
 from .krgpt_api import KrGPT  # noqa: F401
 from .lightllm_api import LightllmAPI  # noqa: F401
 from .llama2 import Llama2, Llama2Chat  # noqa: F401, F403
 from .lmdeploy_pytorch import LmdeployPytorchModel  # noqa: F401
 from .minimax_api import MiniMax  # noqa: F401
+from .mistral_api import Mistral  # noqa: F401
 from .mixtral import Mixtral  # noqa: F401
 from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
 from .moonshot_api import MoonShot  # noqa: F401
@@ -28,7 +30,9 @@ from .qwen_api import Qwen  # noqa: F401
 from .sensetime_api import SenseTime  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
+from .unigpt_api import UniGPT  # noqa: F401
 from .vllm import VLLM  # noqa: F401
 from .xunfei_api import XunFei  # noqa: F401
+from .yayi_api import Yayi  # noqa: F401
 from .zhipuai_api import ZhiPuAI  # noqa: F401
 from .zhipuai_v2_api import ZhiPuV2AI  # noqa: F401
```
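After these registry changes, the added wrappers should be importable from the package root; a quick sanity check (hypothetical session):

```python
# Names added or extended by this commit, now exported by opencompass.models.
from opencompass.models import BaiChuan3, Hunyuan, Mistral, UniGPT, Yayi
```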
opencompass/models/ai360_api.py

```diff
@@ -60,13 +60,13 @@ class AI360GPT(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -83,13 +83,13 @@ class AI360GPT(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/baichuan_api.py

The same signature update as above:

```diff
@@ -59,13 +59,13 @@ class BaiChuan(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -82,13 +82,13 @@ class BaiChuan(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```

The hunk at `@@ -157,3 +157,127 @@` appends a new `BaiChuan3` class after `raise RuntimeError(response)`:

```python
class BaiChuan3(BaseAPIModel):

    def __init__(
        self,
        path: str,
        api_key: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):  # noqa E125
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        self.api_key = api_key
        self.url = url
        self.model = path

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            history = []
            prompt = input
        else:
            messages = []
            msg_buffer, last_role = [], None
            for item in input:
                role = 'BOT' if item['role'] == 'BOT' else 'USER'
                if role != last_role and last_role is not None:
                    messages.append({
                        'data': '\n'.join(msg_buffer),
                        'from': 0 if last_role == 'USER' else 1
                    })
                    msg_buffer = []
                msg_buffer.append(item['prompt'])
                last_role = role
            messages.append({
                'data': '\n'.join(msg_buffer),
                'from': 0 if last_role == 'USER' else 1
            })
            history = messages[:-1]
            prompt = messages[-1]['data']

        data = {
            'access_token_key': self.api_key,
            'app_info': {'id': 123},
            'prompt': {'data': prompt},
            'history': history,
        }

        for _ in range(self.retry):
            try:
                response = requests.post(self.url, json=data)
            except Exception as e:
                print(e)
                continue
            if response is None or response.status_code != 200:
                code = response.status_code if response else -1
                print(f'[chat_api]-[failed] request err, status_code: {code}')
                continue
            try:
                response = response.json()
            except Exception as e:
                print(e)
                continue
            print(response)
            status = response.get('answer', {}).get('status', 0)
            session_status = response.get('session_info', {}).get('status', 0)
            if status < 0 or session_status < 0:
                print('[chat_api]-[warn] prompt or answer is unsafe')
                return 'Rejection: unsafe prompt or answer'
            return response.get('answer', {}).get('data', '')

        raise RuntimeError(response['msg'])
```
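A minimal usage sketch for the new class, assuming placeholder credentials and endpoint (note the `app_info` id is hard-coded to `123` in the request payload):

```python
# Hypothetical configuration; api_key and url are placeholders.
model = BaiChuan3(
    path='baichuan3',
    api_key='YOUR_ACCESS_TOKEN_KEY',
    url='https://<baichuan-endpoint>/chat',
)
print(model.generate(['Hello, who are you?'], max_out_len=512))
```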
opencompass/models/baidu_api.py

```diff
@@ -88,13 +88,13 @@ class ERNIEBot(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -111,13 +111,13 @@ class ERNIEBot(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/base.py

```diff
@@ -129,7 +129,7 @@ class BaseModel:
         applicable.

         Args:
-            prompt_template (List[str or PromptList]): A prompt
+            prompt_template (List[PromptType]): A prompt
                 template (potentially before being wrapped by meta template).
             mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
@@ -266,7 +266,7 @@ class LMTemplateParser:
         applicable.

         Args:
-            prompt_template (List[str or PromptList]): A prompt
+            prompt_template (List[PromptType]): A prompt
                 template (potentially before being wrapped by meta template).
             mode (str): Parsing mode. Choices are 'ppl' and 'gen'.
```
opencompass/models/base_api.py

```diff
@@ -60,7 +60,7 @@ class BaseAPIModel(BaseModel):
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -111,7 +111,7 @@ class BaseAPIModel(BaseModel):
         """Get perplexity scores given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings.
+            inputs (List[PromptType]): A list of strings.
             mask_length (Optional[List[int]]): A list of mask lengths. If
                 provided, the perplexity scores will be calculated with the
                 first mask_length[i] tokens masked out. It's okay to skip
@@ -200,12 +200,12 @@ class APITemplateParser:
             {'role': 'user', 'prompt': '...'}).

         Args:
-            prompt_template (List[str or PromptList]): An intermidate prompt
+            prompt_template (List[PromptType]): An intermidate prompt
                 template (potentially before being wrapped by meta template).
             mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

         Returns:
-            List[str or PromptList]: The finalized prompt or a conversation.
+            List[PromptType]: The finalized prompt or a conversation.
         """
         assert isinstance(prompt_template, (str, list, PromptList, tuple))
```
opencompass/models/bytedance_api.py

```diff
@@ -64,13 +64,13 @@ class ByteDance(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -87,13 +87,13 @@ class ByteDance(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/claude_api/claude_api.py

```diff
@@ -52,13 +52,13 @@ class Claude(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -74,13 +74,13 @@ class Claude(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/gemini_api.py

```diff
@@ -58,13 +58,13 @@ class Gemini(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -81,13 +81,13 @@ class Gemini(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -234,13 +234,13 @@ class GeminiAllesAPIN(Gemini):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/huggingface.py

```diff
@@ -723,7 +723,7 @@ class HuggingFaceChatGLM3(HuggingFace):
         self.num_extra_tokens = num_extra_tokens

     def generate(self,
-                 inputs: List[str or PromptList],
+                 inputs: List[PromptType],
                  max_out_len: int = 512,
                  skip_overlength=False,
                  **kwargs) -> str:
```
opencompass/models/hunyuan_api.py (new file, mode 100644)

```python
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class Hunyuan(BaseAPIModel):

    def __init__(
        self,
        path: str,
        secret_id: str,
        secret_key: str,
        endpoint: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):  # noqa E125
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
        )
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.endpoint = endpoint

        from tencentcloud.common import credential
        from tencentcloud.common.common_client import CommonClient
        from tencentcloud.common.profile.client_profile import ClientProfile
        from tencentcloud.common.profile.http_profile import HttpProfile

        cred = credential.Credential(self.secret_id, self.secret_key)
        httpProfile = HttpProfile()
        httpProfile.endpoint = self.endpoint
        clientProfile = ClientProfile()
        clientProfile.httpProfile = httpProfile
        self.client = CommonClient('hunyuan',
                                   '2023-09-01',
                                   cred,
                                   'ap-beijing',
                                   profile=clientProfile)

    def generate(self,
                 inputs: List[PromptType],
                 max_out_len: int = 512) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(self, input: PromptType, max_out_len: int = 512) -> str:
        """Generate results given an input.

        Args:
            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'Content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['Role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['Role'] = 'assistant'
                messages.append(msg)

        from tencentcloud.common.exception.tencent_cloud_sdk_exception import \
            TencentCloudSDKException

        data = {'Messages': messages}
        for _ in range(self.retry):
            try:
                resp = self.client.call_sse('ChatPro', data)
                contents = []
                for event in resp:
                    part = json.loads(event['data'])
                    contents.append(part['Choices'][0]['Delta']['Content'])
                answer = ''.join(contents)
            except TencentCloudSDKException as err:
                print(err)
            print(answer)
            return answer

        raise RuntimeError(f'Failed to respond in {self.retry} retrys')
```
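A minimal usage sketch, assuming Tencent Cloud credentials and the tencentcloud SDK are available (the id, key, and endpoint values below are placeholders). Note that as written, `_generate` returns inside the first loop iteration, so a `TencentCloudSDKException` on the first attempt leaves `answer` unbound rather than triggering a retry:

```python
# Hypothetical configuration; secret_id, secret_key, and endpoint are placeholders.
model = Hunyuan(
    path='hunyuan',
    secret_id='YOUR_SECRET_ID',
    secret_key='YOUR_SECRET_KEY',
    endpoint='hunyuan.tencentcloudapi.com',
)
print(model.generate(['Hello, who are you?'], max_out_len=512))
```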
opencompass/models/llama2.py

```diff
@@ -199,7 +199,7 @@ class Llama2Chat(BaseModel):
         self.tokenizer = Tokenizer(tokenizer_path)

     def generate(self,
-                 inputs: List[str or PromptList],
+                 inputs: List[PromptType],
                  max_out_len: int = 512,
                  temperature: float = 0.6) -> str:
         """Generate response from input prompt.
```
opencompass/models/lmdeploy_pytorch.py

```diff
@@ -124,13 +124,13 @@ class LmdeployPytorchModel(BaseModel):
     def _generate(self,
                   generator,
                   session_id,
-                  prompt: str or PromptList,
+                  prompt: PromptType,
                   gen_config=None,
                   end_str: Optional[str] = None) -> str:
         """Generate results given a list of inputs.

         Args:
-            prompt (str or PromptList): A string or PromptDict.
+            prompt (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             gen_config (EngineGenerationConfig, optional): Generation
```
opencompass/models/minimax_api.py

```diff
@@ -60,13 +60,13 @@ class MiniMax(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -83,13 +83,13 @@ class MiniMax(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in Test'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/mistral_api.py (new file, mode 100644)

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import requests

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class Mistral(BaseAPIModel):

    def __init__(
        self,
        path: str,
        api_key: str,
        url: str,
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):  # noqa E125
        super().__init__(
            path=path,
            max_seq_len=max_seq_len,
            query_per_second=query_per_second,
            meta_template=meta_template,
            retry=retry,
        )
        self.api_key = api_key
        self.url = url
        self.model = path

    def generate(self,
                 inputs: List[PromptType],
                 max_out_len: int = 512) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(self, input: PromptType, max_out_len: int = 512) -> str:
        """Generate results given an input.

        Args:
            inputs (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)
            messages[-1]['role'] = 'user'

        data = {
            'model': self.path,
            'messages': messages,
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

        from pprint import pprint
        print('-' * 128)
        pprint(data)
        for _ in range(self.retry):
            try:
                response = requests.post(self.url, json=data, headers=headers)
            except Exception as e:
                print(e)
                continue
            try:
                response = response.json()
            except Exception as e:
                print(e)
                continue
            print('=' * 128)
            pprint(response)
            try:
                msg = response['choices'][0]['message']['content']
            except Exception as e:
                print(e)
                continue
            return msg

        raise RuntimeError(f'Failed to respond in {self.retry} retrys')
```
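A minimal usage sketch against a Mistral-style chat-completions endpoint; the model name, URL, and key below are placeholders, not values from this commit. Note that `_generate` rewrites the final message's role to `user` before sending, so conversations are always submitted ending on a user turn:

```python
# Hypothetical configuration; path, api_key, and url are placeholders.
model = Mistral(
    path='mistral-small-latest',
    api_key='YOUR_MISTRAL_API_KEY',
    url='https://api.mistral.ai/v1/chat/completions',
)
print(model.generate(['Summarize the OpenCompass project in one line.']))
```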
opencompass/models/moonshot_api.py

```diff
@@ -55,13 +55,13 @@ class MoonShot(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -78,13 +78,13 @@ class MoonShot(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -98,29 +98,27 @@ class MoonShot(BaseAPIModel):
             messages = [{'role': 'user', 'content': input}]
         else:
             messages = []
+            msg_buffer, last_role = [], None
             for item in input:
-                msg = {'content': item['prompt']}
-                if item['role'] == 'HUMAN':
-                    msg['role'] = 'user'
-                elif item['role'] == 'BOT':
-                    msg['role'] = 'assistant'
-                messages.append(msg)
-
-        system = {
-            'role': 'system',
-            'content': self.system_prompt
-            # '你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。'
-            # '你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一些涉及恐怖主义,种族歧视,'
-            # '黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。'
-        }
-        messages.insert(0, system)
-
-        data = {'model': self.model, 'messages': messages}
+                item['role'] = 'assistant' if item['role'] == 'BOT' else 'user'
+                if item['role'] != last_role and last_role is not None:
+                    messages.append({
+                        'content': '\n'.join(msg_buffer),
+                        'role': last_role
+                    })
+                    msg_buffer = []
+                msg_buffer.append(item['prompt'])
+                last_role = item['role']
+            messages.append({
+                'content': '\n'.join(msg_buffer),
+                'role': last_role
+            })
+
+        if self.system_prompt:
+            system = {
+                'role': 'system',
+                'content': self.system_prompt
+            }
+            messages.insert(0, system)
+
+        data = {
+            'model': self.model,
+            'messages': messages,
+        }

         max_num_retries = 0
         while max_num_retries < self.retry:
```
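The new buffering logic mirrors `merge_role` in react.py above: consecutive same-role turns are joined with newlines so the payload strictly alternates user and assistant, and the system message (previously always inserted, with a commented-out Chinese default describing Kimi, Moonshot AI's assistant) is now prepended only when `self.system_prompt` is non-empty. A sketch of the effect on hypothetical turns:

```python
# Hypothetical PromptList-style turns before conversion.
turns = [
    {'role': 'HUMAN', 'prompt': 'Context paragraph'},
    {'role': 'HUMAN', 'prompt': 'Question about it'},
    {'role': 'BOT', 'prompt': 'Answer'},
]
# After the merging pass, the request payload contains:
# [{'role': 'user', 'content': 'Context paragraph\nQuestion about it'},
#  {'role': 'assistant', 'content': 'Answer'}]
```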
opencompass/models/nanbeige_api.py

```diff
@@ -52,13 +52,13 @@ class Nanbeige(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -75,13 +75,13 @@ class Nanbeige(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/openai_api.py

Besides the `PromptType` signature update, this file gains an `api-key` header, an error log of the raw response, a longer rate-limit backoff, and handling for the `invalid_prompt` error code:

```diff
@@ -103,14 +103,14 @@ class OpenAI(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
         temperature: float = 0.7,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -132,12 +132,12 @@ class OpenAI(BaseAPIModel):
                              [temperature] * len(inputs)))
         return results

-    def _generate(self, input: str or PromptList, max_out_len: int,
+    def _generate(self, input: PromptType, max_out_len: int,
                   temperature: float) -> str:
         """Generate results given a list of inputs.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -207,6 +207,7 @@ class OpenAI(BaseAPIModel):
             header = {
                 'Authorization': f'Bearer {key}',
                 'content-type': 'application/json',
+                'api-key': key,
             }
             if self.orgs:
@@ -239,6 +240,7 @@ class OpenAI(BaseAPIModel):
                 self.logger.error('JsonDecode error, got',
                                   str(raw_response.content))
                 continue
+            self.logger.error(str(response))
            try:
                 if self.logprobs:
                     return response['choices']
@@ -247,13 +249,16 @@ class OpenAI(BaseAPIModel):
             except KeyError:
                 if 'error' in response:
                     if response['error']['code'] == 'rate_limit_exceeded':
-                        time.sleep(1)
+                        time.sleep(10)
                         self.logger.warn('Rate limit exceeded, retrying...')
                         continue
                     elif response['error']['code'] == 'insufficient_quota':
                         self.invalid_keys.add(key)
                         self.logger.warn(f'insufficient_quota key: {key}')
                         continue
+                    elif response['error']['code'] == 'invalid_prompt':
+                        self.logger.warn('Invalid prompt:', str(input))
+                        return ''

                     self.logger.error('Find error message in response: ',
                                       str(response['error']))
@@ -363,12 +368,12 @@ class OpenAIAllesAPIN(OpenAI):
             'content-type': 'application/json',
         }

-    def _generate(self, input: str or PromptList, max_out_len: int,
+    def _generate(self, input: PromptType, max_out_len: int,
                   temperature: float) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```
opencompass/models/pangu_api.py

```diff
@@ -67,13 +67,13 @@ class PanGu(BaseAPIModel):
     def generate(
         self,
-        inputs: List[str or PromptList],
+        inputs: List[PromptType],
         max_out_len: int = 512,
     ) -> List[str]:
         """Generate results given a list of inputs.

         Args:
-            inputs (List[str or PromptList]): A list of strings or PromptDicts.
+            inputs (List[PromptType]): A list of strings or PromptDicts.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
@@ -117,13 +117,13 @@ class PanGu(BaseAPIModel):
     def _generate(
         self,
-        input: str or PromptList,
+        input: PromptType,
         max_out_len: int = 512,
     ) -> str:
         """Generate results given an input.

         Args:
-            inputs (str or PromptList): A string or PromptDict.
+            inputs (PromptType): A string or PromptDict.
                 The PromptDict should be organized in OpenCompass'
                 API format.
             max_out_len (int): The maximum length of the output.
```