OpenDAS / opencompass · Commits

Commit ba027eee (unverified), authored Jan 02, 2024 by HUANG Fei, committed by GitHub on Jan 02, 2024

[Feature] Add support of qwen api (#735)

Parent: 33f8df1c

Showing 3 changed files with 194 additions and 0 deletions:

  configs/api_examples/eval_api_qwen.py    +40  -0
  opencompass/models/__init__.py            +1  -0
  opencompass/models/qwen_api.py          +153  -0
configs/api_examples/eval_api_qwen.py (new file, mode 100644)
from mmengine.config import read_base
from opencompass.models import Qwen
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
    from ..summarizers.medium import summarizer
    from ..datasets.ceval.ceval_gen import ceval_datasets

datasets = [
    *ceval_datasets,
]

models = [
    dict(
        abbr='qwen-max',
        type=Qwen,
        path='qwen-max',
        key='xxxxxxxxxxxxxxxx',  # please provide your key
        generation_kwargs={
            'enable_search': False,
        },
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=8),
]

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=1,
        concurrent_users=1,
        task=dict(type=OpenICLInferTask)),
)

work_dir = "outputs/api_qwen/"
opencompass/models/__init__.py

@@ -19,6 +19,7 @@ from .modelscope import ModelScope, ModelScopeCausalLM  # noqa: F401, F403
 from .moonshot_api import MoonShot  # noqa: F401
 from .openai_api import OpenAI  # noqa: F401
 from .pangu_api import PanGu  # noqa: F401
+from .qwen_api import Qwen  # noqa: F401
 from .sensetime_api import SenseTime  # noqa: F401
 from .turbomind import TurboMindModel  # noqa: F401
 from .turbomind_tis import TurboMindTisModel  # noqa: F401
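
With this re-export in place, the wrapper can be imported from the package root (from opencompass.models import Qwen), which is what the example config above relies on.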
opencompass/models/qwen_api.py (new file, mode 100644)
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class Qwen(BaseAPIModel):
    """Model wrapper around Qwen.

    Documentation:
        https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions/

    Args:
        path (str): The name of the qwen model, e.g. `qwen-max`.
        key (str): Authorization key.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        max_seq_len (int): Unused here.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case of the requirement of injecting or
            wrapping any meta instructions.
        retry (int): Number of retries if the API call fails. Defaults to 5.
    """

    def __init__(self,
                 path: str,
                 key: str,
                 query_per_second: int = 1,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 5,
                 generation_kwargs: Dict = {}):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        import dashscope
        dashscope.api_key = key
        self.dashscope = dashscope

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            input (PromptType): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        assert isinstance(input, (str, PromptList))
        """
        {
            "messages": [
                {"role": "user", "content": "Please introduce yourself."},
                {"role": "assistant", "content": "I am Tongyi Qianwen."},
                {"role": "user", "content": "I am in Shanghai. Where can I go on the weekend?"},
                {"role": "assistant", "content": "Shanghai is a vibrant city with a rich cultural atmosphere."},
                {"role": "user", "content": "What is the weather like there on the weekend?"}
            ]
        }
        """
        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                messages.append(msg)
        data = {'messages': messages}
        data.update(self.generation_kwargs)

        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            response = self.dashscope.Generation.call(
                model=self.path,
                **data,
            )
            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # If a connection error occurs, frequent requests will cause
                # a continuously unstable network, so wait here to slow down
                # the requests.
                self.wait()
                continue
            if response.status_code == 200:
                try:
                    msg = response.output.text
                    return msg
                except KeyError:
                    print(response)
                    self.logger.error(str(response.status_code))
                    time.sleep(1)
                    continue
            if ('Range of input length should be ' in response.message
                    or  # input too long
                    'Input data may contain inappropriate content.'
                    in response.message):  # bad input
                print(response.message)
                return ''
            print(response)
            max_num_retries += 1

        raise RuntimeError(response.message)
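
For reference, a minimal standalone sketch of exercising the wrapper directly (the key is a placeholder; a valid DashScope credential and the dashscope package are required for real calls):

    from opencompass.models import Qwen

    # Placeholder credentials; replace with a real DashScope API key.
    model = Qwen(path='qwen-max',
                 key='xxxxxxxxxxxxxxxx',
                 generation_kwargs={'enable_search': False})

    # generate() maps a list of prompts to a list of completions.
    outputs = model.generate(['Briefly introduce yourself.'], max_out_len=256)
    print(outputs[0])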