"tools/vscode:/vscode.git/clone" did not exist on "1d67149511404350c0c0bced73773afb6fcbfa21"
Unverified commit 793e32c9, authored by Songyang Zhang, committed by GitHub
Browse files

[Feature] Update API implementation (#834)

parent 2ee8e8a1
import hashlib
import json
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -22,7 +20,6 @@ class BaiChuan(BaseAPIModel): ...@@ -22,7 +20,6 @@ class BaiChuan(BaseAPIModel):
path (str): The name of Baichuan model. path (str): The name of Baichuan model.
e.g. `Baichuan2-53B` e.g. `Baichuan2-53B`
api_key (str): Provided api key api_key (str): Provided api key
secretkey (str): secretkey in order to obtain access_token
url (str): Provide url url (str): Provide url
query_per_second (int): The maximum queries allowed per second query_per_second (int): The maximum queries allowed per second
between two consecutive calls of the API. Defaults to 1. between two consecutive calls of the API. Defaults to 1.
...@@ -37,7 +34,6 @@ class BaiChuan(BaseAPIModel): ...@@ -37,7 +34,6 @@ class BaiChuan(BaseAPIModel):
self, self,
path: str, path: str,
api_key: str, api_key: str,
secret_key: str,
url: str, url: str,
query_per_second: int = 2, query_per_second: int = 2,
max_seq_len: int = 2048, max_seq_len: int = 2048,
...@@ -48,6 +44,7 @@ class BaiChuan(BaseAPIModel): ...@@ -48,6 +44,7 @@ class BaiChuan(BaseAPIModel):
'top_p': 0.85, 'top_p': 0.85,
'top_k': 5, 'top_k': 5,
'with_search_enhance': False, 'with_search_enhance': False,
'stream': False,
}): # noqa E125 }): # noqa E125
super().__init__(path=path, super().__init__(path=path,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
...@@ -57,7 +54,6 @@ class BaiChuan(BaseAPIModel): ...@@ -57,7 +54,6 @@ class BaiChuan(BaseAPIModel):
generation_kwargs=generation_kwargs) generation_kwargs=generation_kwargs)
self.api_key = api_key self.api_key = api_key
self.secret_key = secret_key
self.url = url self.url = url
self.model = path self.model = path
...@@ -119,36 +115,28 @@ class BaiChuan(BaseAPIModel): ...@@ -119,36 +115,28 @@ class BaiChuan(BaseAPIModel):
data = {'model': self.model, 'messages': messages} data = {'model': self.model, 'messages': messages}
data.update(self.generation_kwargs) data.update(self.generation_kwargs)
# Helper: UTF-8-encodes input_string, feeds it to an MD5 hash object and
# returns the digest as a hexadecimal string. It is called just below to
# build the request signature from secret_key + JSON payload + timestamp.
def calculate_md5(input_string):
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(self.secret_key + json_data +
str(time_stamp))
headers = { headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': 'Bearer ' + self.api_key, 'Authorization': 'Bearer ' + self.api_key,
'X-BC-Request-Id': 'your requestId',
'X-BC-Timestamp': str(time_stamp),
'X-BC-Signature': signature,
'X-BC-Sign-Algo': 'MD5',
} }
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.acquire() self.acquire()
raw_response = requests.request('POST', try:
url=self.url, raw_response = requests.request('POST',
headers=headers, url=self.url,
json=data) headers=headers,
response = raw_response.json() json=data)
self.release() response = raw_response.json()
except Exception as err:
print('Request Error:{}'.format(err))
time.sleep(3)
continue
self.release()
# print(response.keys())
# print(response['choices'][0]['message']['content'])
if response is None: if response is None:
print('Connection error, reconnect.') print('Connection error, reconnect.')
# if connect error, frequent requests will cause # if connect error, frequent requests will cause
...@@ -156,13 +144,13 @@ class BaiChuan(BaseAPIModel): ...@@ -156,13 +144,13 @@ class BaiChuan(BaseAPIModel):
# to slow down the request # to slow down the request
self.wait() self.wait()
continue continue
if raw_response.status_code == 200 and response['code'] == 0: if raw_response.status_code == 200:
msg = response['data']['messages'][0]['content'] msg = response['choices'][0]['message']['content']
return msg return msg
if response['code'] != 0: if raw_response.status_code != 200:
print(response) print(raw_response)
time.sleep(1) time.sleep(1)
continue continue
print(response) print(response)
......
...@@ -54,6 +54,9 @@ class ERNIEBot(BaseAPIModel): ...@@ -54,6 +54,9 @@ class ERNIEBot(BaseAPIModel):
self.secretkey = secretkey self.secretkey = secretkey
self.key = key self.key = key
self.url = url self.url = url
access_token, _ = self._generate_access_token()
self.access_token = access_token
print(access_token)
def _generate_access_token(self): def _generate_access_token(self):
try: try:
...@@ -154,12 +157,18 @@ class ERNIEBot(BaseAPIModel): ...@@ -154,12 +157,18 @@ class ERNIEBot(BaseAPIModel):
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.acquire() self.acquire()
access_token, _ = self._generate_access_token() try:
raw_response = requests.request('POST', raw_response = requests.request('POST',
url=self.url + access_token, url=self.url +
headers=self.headers, self.access_token,
json=data) headers=self.headers,
response = raw_response.json() json=data)
response = raw_response.json()
except Exception as err:
print('Request Error:{}'.format(err))
time.sleep(3)
continue
self.release() self.release()
if response is None: if response is None:
...@@ -176,6 +185,10 @@ class ERNIEBot(BaseAPIModel): ...@@ -176,6 +185,10 @@ class ERNIEBot(BaseAPIModel):
except KeyError: except KeyError:
print(response) print(response)
self.logger.error(str(response['error_code'])) self.logger.error(str(response['error_code']))
if response['error_code'] == 336007:
# exceed max length
return ''
time.sleep(1) time.sleep(1)
continue continue
...@@ -189,7 +202,8 @@ class ERNIEBot(BaseAPIModel): ...@@ -189,7 +202,8 @@ class ERNIEBot(BaseAPIModel):
or response['error_code'] == 216100 or response['error_code'] == 216100
or response['error_code'] == 336001 or response['error_code'] == 336001
or response['error_code'] == 336003 or response['error_code'] == 336003
or response['error_code'] == 336000): or response['error_code'] == 336000
or response['error_code'] == 336007):
print(response['error_msg']) print(response['error_msg'])
return '' return ''
print(response) print(response)
......
...@@ -90,7 +90,7 @@ class MiniMax(BaseAPIModel): ...@@ -90,7 +90,7 @@ class MiniMax(BaseAPIModel):
Args: Args:
inputs (str or PromptList): A string or PromptDict. inputs (str or PromptList): A string or PromptDict.
The PromptDict should be organized in OpenCompass' The PromptDict should be organized in Test'
API format. API format.
max_out_len (int): The maximum length of the output. max_out_len (int): The maximum length of the output.
...@@ -102,7 +102,7 @@ class MiniMax(BaseAPIModel): ...@@ -102,7 +102,7 @@ class MiniMax(BaseAPIModel):
if isinstance(input, str): if isinstance(input, str):
messages = [{ messages = [{
'sender_type': 'USER', 'sender_type': 'USER',
'sender_name': 'OpenCompass', 'sender_name': 'Test',
'text': input 'text': input
}] }]
else: else:
...@@ -111,7 +111,7 @@ class MiniMax(BaseAPIModel): ...@@ -111,7 +111,7 @@ class MiniMax(BaseAPIModel):
msg = {'text': item['prompt']} msg = {'text': item['prompt']}
if item['role'] == 'HUMAN': if item['role'] == 'HUMAN':
msg['sender_type'] = 'USER' msg['sender_type'] = 'USER'
msg['sender_name'] = 'OpenCompass' msg['sender_name'] = 'Test'
elif item['role'] == 'BOT': elif item['role'] == 'BOT':
msg['sender_type'] = 'BOT' msg['sender_type'] = 'BOT'
msg['sender_name'] = 'MM智能助理' msg['sender_name'] = 'MM智能助理'
...@@ -135,15 +135,19 @@ class MiniMax(BaseAPIModel): ...@@ -135,15 +135,19 @@ class MiniMax(BaseAPIModel):
'messages': 'messages':
messages messages
} }
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.acquire() self.acquire()
raw_response = requests.request('POST', try:
url=self.url, raw_response = requests.request('POST',
headers=self.headers, url=self.url,
json=data) headers=self.headers,
response = raw_response.json() json=data)
response = raw_response.json()
except Exception as err:
print('Request Error:{}'.format(err))
time.sleep(3)
continue
self.release() self.release()
if response is None: if response is None:
...@@ -157,6 +161,7 @@ class MiniMax(BaseAPIModel): ...@@ -157,6 +161,7 @@ class MiniMax(BaseAPIModel):
# msg = json.load(response.text) # msg = json.load(response.text)
# response # response
msg = response['reply'] msg = response['reply']
# msg = response['choices']['messages']['text']
return msg return msg
# sensitive content, prompt overlength, network error # sensitive content, prompt overlength, network error
# or illegal prompt # or illegal prompt
......
...@@ -125,10 +125,15 @@ class MoonShot(BaseAPIModel): ...@@ -125,10 +125,15 @@ class MoonShot(BaseAPIModel):
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.acquire() self.acquire()
raw_response = requests.request('POST', try:
url=self.url, raw_response = requests.request('POST',
headers=self.headers, url=self.url,
json=data) headers=self.headers,
json=data)
except Exception as err:
print('Request Error:{}'.format(err))
time.sleep(2)
continue
response = raw_response.json() response = raw_response.json()
self.release() self.release()
...@@ -153,12 +158,14 @@ class MoonShot(BaseAPIModel): ...@@ -153,12 +158,14 @@ class MoonShot(BaseAPIModel):
elif raw_response.status_code == 400: elif raw_response.status_code == 400:
print(messages, response) print(messages, response)
print('请求失败,状态码:', raw_response) print('请求失败,状态码:', raw_response)
msg = 'The request was rejected because high risk'
return msg
time.sleep(1) time.sleep(1)
continue continue
elif raw_response.status_code == 429: elif raw_response.status_code == 429:
print(messages, response) print(messages, response)
print('请求失败,状态码:', raw_response) print('请求失败,状态码:', raw_response)
time.sleep(3) time.sleep(5)
continue continue
max_num_retries += 1 max_num_retries += 1
......
...@@ -109,6 +109,8 @@ class Qwen(BaseAPIModel): ...@@ -109,6 +109,8 @@ class Qwen(BaseAPIModel):
msg['role'] = 'user' msg['role'] = 'user'
elif item['role'] == 'BOT': elif item['role'] == 'BOT':
msg['role'] = 'assistant' msg['role'] = 'assistant'
elif item['role'] == 'SYSTEM':
msg['role'] = 'system'
messages.append(msg) messages.append(msg)
data = {'messages': messages} data = {'messages': messages}
...@@ -117,10 +119,16 @@ class Qwen(BaseAPIModel): ...@@ -117,10 +119,16 @@ class Qwen(BaseAPIModel):
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.acquire() self.acquire()
response = self.dashscope.Generation.call( try:
model=self.path, response = self.dashscope.Generation.call(
**data, model=self.path,
) **data,
)
except Exception as err:
print('Request Error:{}'.format(err))
time.sleep(1)
continue
self.release() self.release()
if response is None: if response is None:
...@@ -140,6 +148,13 @@ class Qwen(BaseAPIModel): ...@@ -140,6 +148,13 @@ class Qwen(BaseAPIModel):
self.logger.error(str(response.status_code)) self.logger.error(str(response.status_code))
time.sleep(1) time.sleep(1)
continue continue
if response.status_code == 429:
print('Rate limited')
time.sleep(2)
continue
if response.status_code == 400:
msg = 'Output data may contain inappropriate content.'
return msg
if ('Range of input length should be ' in response.message if ('Range of input length should be ' in response.message
or # input too long or # input too long
......
dashscope # Qwen
sseclient-py==1.7.2 sseclient-py==1.7.2
volcengine # bytedance volcengine # bytedance
websocket-client websocket-client
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment