Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
opencompass
Commits
793e32c9
"docs/vscode:/vscode.git/clone" did not exist on "d39899e85c5c29b3aeb2ea36d19f59214de60336"
Unverified
Commit
793e32c9
authored
Jan 24, 2024
by
Songyang Zhang
Committed by
GitHub
Jan 24, 2024
Browse files
[Feature] Update API implementation (#834)
parent
2ee8e8a1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
85 additions
and
55 deletions
+85
-55
opencompass/models/baichuan_api.py
opencompass/models/baichuan_api.py
+18
-30
opencompass/models/baidu_api.py
opencompass/models/baidu_api.py
+21
-7
opencompass/models/minimax_api.py
opencompass/models/minimax_api.py
+14
-9
opencompass/models/moonshot_api.py
opencompass/models/moonshot_api.py
+12
-5
opencompass/models/qwen_api.py
opencompass/models/qwen_api.py
+19
-4
requirements/api.txt
requirements/api.txt
+1
-0
No files found.
opencompass/models/baichuan_api.py
View file @
793e32c9
import
hashlib
import
json
import
time
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Dict
,
List
,
Optional
,
Union
...
...
@@ -22,7 +20,6 @@ class BaiChuan(BaseAPIModel):
path (str): The name of Baichuan model.
e.g. `Baichuan2-53B`
api_key (str): Provided api key
secretkey (str): secretkey in order to obtain access_token
url (str): Provide url
query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1.
...
...
@@ -37,7 +34,6 @@ class BaiChuan(BaseAPIModel):
self
,
path
:
str
,
api_key
:
str
,
secret_key
:
str
,
url
:
str
,
query_per_second
:
int
=
2
,
max_seq_len
:
int
=
2048
,
...
...
@@ -48,6 +44,7 @@ class BaiChuan(BaseAPIModel):
'top_p'
:
0.85
,
'top_k'
:
5
,
'with_search_enhance'
:
False
,
'stream'
:
False
,
}):
# noqa E125
super
().
__init__
(
path
=
path
,
max_seq_len
=
max_seq_len
,
...
...
@@ -57,7 +54,6 @@ class BaiChuan(BaseAPIModel):
generation_kwargs
=
generation_kwargs
)
self
.
api_key
=
api_key
self
.
secret_key
=
secret_key
self
.
url
=
url
self
.
model
=
path
...
...
@@ -119,36 +115,28 @@ class BaiChuan(BaseAPIModel):
data
=
{
'model'
:
self
.
model
,
'messages'
:
messages
}
data
.
update
(
self
.
generation_kwargs
)
def
calculate_md5
(
input_string
):
md5
=
hashlib
.
md5
()
md5
.
update
(
input_string
.
encode
(
'utf-8'
))
encrypted
=
md5
.
hexdigest
()
return
encrypted
json_data
=
json
.
dumps
(
data
)
time_stamp
=
int
(
time
.
time
())
signature
=
calculate_md5
(
self
.
secret_key
+
json_data
+
str
(
time_stamp
))
headers
=
{
'Content-Type'
:
'application/json'
,
'Authorization'
:
'Bearer '
+
self
.
api_key
,
'X-BC-Request-Id'
:
'your requestId'
,
'X-BC-Timestamp'
:
str
(
time_stamp
),
'X-BC-Signature'
:
signature
,
'X-BC-Sign-Algo'
:
'MD5'
,
}
max_num_retries
=
0
while
max_num_retries
<
self
.
retry
:
self
.
acquire
()
try
:
raw_response
=
requests
.
request
(
'POST'
,
url
=
self
.
url
,
headers
=
headers
,
json
=
data
)
response
=
raw_response
.
json
()
self
.
release
()
except
Exception
as
err
:
print
(
'Request Error:{}'
.
format
(
err
))
time
.
sleep
(
3
)
continue
self
.
release
()
# print(response.keys())
# print(response['choices'][0]['message']['content'])
if
response
is
None
:
print
(
'Connection error, reconnect.'
)
# if connect error, frequent requests will casuse
...
...
@@ -156,13 +144,13 @@ class BaiChuan(BaseAPIModel):
# to slow down the request
self
.
wait
()
continue
if
raw_response
.
status_code
==
200
and
response
[
'code'
]
==
0
:
if
raw_response
.
status_code
==
200
:
msg
=
response
[
'
data'
][
'message
s'
][
0
][
'content'
]
msg
=
response
[
'
choices'
][
0
][
'message
'
][
'content'
]
return
msg
if
response
[
'
code
'
]
!=
0
:
print
(
response
)
if
raw_
response
.
status_
code
!=
20
0
:
print
(
raw_
response
)
time
.
sleep
(
1
)
continue
print
(
response
)
...
...
opencompass/models/baidu_api.py
View file @
793e32c9
...
...
@@ -54,6 +54,9 @@ class ERNIEBot(BaseAPIModel):
self
.
secretkey
=
secretkey
self
.
key
=
key
self
.
url
=
url
access_token
,
_
=
self
.
_generate_access_token
()
self
.
access_token
=
access_token
print
(
access_token
)
def
_generate_access_token
(
self
):
try
:
...
...
@@ -154,12 +157,18 @@ class ERNIEBot(BaseAPIModel):
max_num_retries
=
0
while
max_num_retries
<
self
.
retry
:
self
.
acquire
()
access_token
,
_
=
self
.
_generate_access_token
()
try
:
raw_response
=
requests
.
request
(
'POST'
,
url
=
self
.
url
+
access_token
,
url
=
self
.
url
+
self
.
access_token
,
headers
=
self
.
headers
,
json
=
data
)
response
=
raw_response
.
json
()
except
Exception
as
err
:
print
(
'Request Error:{}'
.
format
(
err
))
time
.
sleep
(
3
)
continue
self
.
release
()
if
response
is
None
:
...
...
@@ -176,6 +185,10 @@ class ERNIEBot(BaseAPIModel):
except
KeyError
:
print
(
response
)
self
.
logger
.
error
(
str
(
response
[
'error_code'
]))
if
response
[
'error_code'
]
==
336007
:
# exceed max length
return
''
time
.
sleep
(
1
)
continue
...
...
@@ -189,7 +202,8 @@ class ERNIEBot(BaseAPIModel):
or
response
[
'error_code'
]
==
216100
or
response
[
'error_code'
]
==
336001
or
response
[
'error_code'
]
==
336003
or
response
[
'error_code'
]
==
336000
):
or
response
[
'error_code'
]
==
336000
or
response
[
'error_code'
]
==
336007
):
print
(
response
[
'error_msg'
])
return
''
print
(
response
)
...
...
opencompass/models/minimax_api.py
View file @
793e32c9
...
...
@@ -90,7 +90,7 @@ class MiniMax(BaseAPIModel):
Args:
inputs (str or PromptList): A string or PromptDict.
The PromptDict should be organized in
OpenCompass
'
The PromptDict should be organized in
Test
'
API format.
max_out_len (int): The maximum length of the output.
...
...
@@ -102,7 +102,7 @@ class MiniMax(BaseAPIModel):
if
isinstance
(
input
,
str
):
messages
=
[{
'sender_type'
:
'USER'
,
'sender_name'
:
'
OpenCompass
'
,
'sender_name'
:
'
Test
'
,
'text'
:
input
}]
else
:
...
...
@@ -111,7 +111,7 @@ class MiniMax(BaseAPIModel):
msg
=
{
'text'
:
item
[
'prompt'
]}
if
item
[
'role'
]
==
'HUMAN'
:
msg
[
'sender_type'
]
=
'USER'
msg
[
'sender_name'
]
=
'
OpenCompass
'
msg
[
'sender_name'
]
=
'
Test
'
elif
item
[
'role'
]
==
'BOT'
:
msg
[
'sender_type'
]
=
'BOT'
msg
[
'sender_name'
]
=
'MM智能助理'
...
...
@@ -135,15 +135,19 @@ class MiniMax(BaseAPIModel):
'messages'
:
messages
}
max_num_retries
=
0
while
max_num_retries
<
self
.
retry
:
self
.
acquire
()
try
:
raw_response
=
requests
.
request
(
'POST'
,
url
=
self
.
url
,
headers
=
self
.
headers
,
json
=
data
)
response
=
raw_response
.
json
()
except
Exception
as
err
:
print
(
'Request Error:{}'
.
format
(
err
))
time
.
sleep
(
3
)
continue
self
.
release
()
if
response
is
None
:
...
...
@@ -157,6 +161,7 @@ class MiniMax(BaseAPIModel):
# msg = json.load(response.text)
# response
msg
=
response
[
'reply'
]
# msg = response['choices']['messages']['text']
return
msg
# sensitive content, prompt overlength, network error
# or illegal prompt
...
...
opencompass/models/moonshot_api.py
View file @
793e32c9
...
...
@@ -125,10 +125,15 @@ class MoonShot(BaseAPIModel):
max_num_retries
=
0
while
max_num_retries
<
self
.
retry
:
self
.
acquire
()
try
:
raw_response
=
requests
.
request
(
'POST'
,
url
=
self
.
url
,
headers
=
self
.
headers
,
json
=
data
)
except
Exception
as
err
:
print
(
'Request Error:{}'
.
format
(
err
))
time
.
sleep
(
2
)
continue
response
=
raw_response
.
json
()
self
.
release
()
...
...
@@ -153,12 +158,14 @@ class MoonShot(BaseAPIModel):
elif
raw_response
.
status_code
==
400
:
print
(
messages
,
response
)
print
(
'请求失败,状态码:'
,
raw_response
)
msg
=
'The request was rejected because high risk'
return
msg
time
.
sleep
(
1
)
continue
elif
raw_response
.
status_code
==
429
:
print
(
messages
,
response
)
print
(
'请求失败,状态码:'
,
raw_response
)
time
.
sleep
(
3
)
time
.
sleep
(
5
)
continue
max_num_retries
+=
1
...
...
opencompass/models/qwen_api.py
View file @
793e32c9
...
...
@@ -109,6 +109,8 @@ class Qwen(BaseAPIModel):
msg
[
'role'
]
=
'user'
elif
item
[
'role'
]
==
'BOT'
:
msg
[
'role'
]
=
'assistant'
elif
item
[
'role'
]
==
'SYSTEM'
:
msg
[
'role'
]
=
'system'
messages
.
append
(
msg
)
data
=
{
'messages'
:
messages
}
...
...
@@ -117,10 +119,16 @@ class Qwen(BaseAPIModel):
max_num_retries
=
0
while
max_num_retries
<
self
.
retry
:
self
.
acquire
()
try
:
response
=
self
.
dashscope
.
Generation
.
call
(
model
=
self
.
path
,
**
data
,
)
except
Exception
as
err
:
print
(
'Request Error:{}'
.
format
(
err
))
time
.
sleep
(
1
)
continue
self
.
release
()
if
response
is
None
:
...
...
@@ -140,6 +148,13 @@ class Qwen(BaseAPIModel):
self
.
logger
.
error
(
str
(
response
.
status_code
))
time
.
sleep
(
1
)
continue
if
response
.
status_code
==
429
:
print
(
'Rate limited'
)
time
.
sleep
(
2
)
continue
if
response
.
status_code
==
400
:
msg
=
'Output data may contain inappropriate content.'
return
msg
if
(
'Range of input length should be '
in
response
.
message
or
# input too long
...
...
requirements/api.txt
View file @
793e32c9
dashscope # Qwen
sseclient-py==1.7.2
volcengine # bytedance
websocket-client
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment