Unverified Commit 8ed022b4 authored by Songyang Zhang's avatar Songyang Zhang Committed by GitHub
Browse files

Update Sensetime API (#844)

parent 4aa74565
...@@ -22,7 +22,22 @@ models = [ ...@@ -22,7 +22,22 @@ models = [
query_per_second=1, query_per_second=1,
max_out_len=2048, max_out_len=2048,
max_seq_len=2048, max_seq_len=2048,
batch_size=8), batch_size=8,
parameters={
"temperature": 0.8,
"top_p": 0.7,
"max_new_tokens": 1024,
"repetition_penalty": 1.05,
"know_ids": [],
"stream": True,
"user": "#*#***TestUser***#*#",
"knowledge_config": {
"control_level": "normal",
"knowledge_base_result": False,
"online_search_result": False
}
}
)
] ]
infer = dict( infer = dict(
......
import json
import os
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -30,24 +32,32 @@ class SenseTime(BaseAPIModel): ...@@ -30,24 +32,32 @@ class SenseTime(BaseAPIModel):
def __init__( def __init__(
self, self,
path: str, path: str,
key: str,
url: str, url: str,
key: str = 'ENV',
query_per_second: int = 2, query_per_second: int = 2,
max_seq_len: int = 2048, max_seq_len: int = 2048,
meta_template: Optional[Dict] = None, meta_template: Optional[Dict] = None,
retry: int = 2, retry: int = 2,
parameters: Optional[Dict] = None,
): ):
super().__init__(path=path, super().__init__(path=path,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
query_per_second=query_per_second, query_per_second=query_per_second,
meta_template=meta_template, meta_template=meta_template,
retry=retry) retry=retry)
if isinstance(key, str):
self.keys = os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key
else:
self.keys = key
self.headers = { self.headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': f'Bearer {key}' 'Authorization': f'Bearer {self.keys}'
} }
self.url = url self.url = url
self.model = path self.model = path
self.params = parameters
def generate( def generate(
self, self,
...@@ -104,38 +114,85 @@ class SenseTime(BaseAPIModel): ...@@ -104,38 +114,85 @@ class SenseTime(BaseAPIModel):
messages.append(msg) messages.append(msg)
data = {'messages': messages, 'model': self.model} data = {'messages': messages, 'model': self.model}
data.update(self.params)
stream = data['stream']
max_num_retries = 0 max_num_retries = 0
while max_num_retries < self.retry: while max_num_retries < self.retry:
self.acquire() self.acquire()
max_num_retries += 1
raw_response = requests.request('POST', raw_response = requests.request('POST',
url=self.url, url=self.url,
headers=self.headers, headers=self.headers,
json=data) json=data)
response = raw_response.json() requests_id = raw_response.headers['X-Request-Id'] # noqa
self.release() self.release()
if response is None: if not stream:
print('Connection error, reconnect.') response = raw_response.json()
# if connect error, frequent requests will casuse
# continuous unstable network, therefore wait here if response is None:
# to slow down the request print('Connection error, reconnect.')
self.wait() # if connect error, frequent requests will casuse
continue # continuous unstable network, therefore wait here
if raw_response.status_code == 200: # to slow down the request
msg = response['data']['choices'][0]['message'] self.wait()
return msg continue
if raw_response.status_code == 200:
if (raw_response.status_code != 200): msg = response['data']['choices'][0]['message']
if response['error']['code'] == 18: return msg
# security issue
return 'error:unsafe' if (raw_response.status_code != 200):
if response['error']['code'] == 18:
# security issue
return 'error:unsafe'
if response['error']['code'] == 17:
return 'error:too long'
else:
print(raw_response.text)
time.sleep(1)
continue
else:
# stream data to msg
raw_response.encoding = 'utf-8'
if raw_response.status_code == 200:
response_text = raw_response.text
data_blocks = response_text.split('data:')
data_blocks = data_blocks[1:]
first_block = json.loads(data_blocks[0])
if first_block['status']['code'] != 0:
msg = f"error:{first_block['status']['code']},"
f" {first_block['status']['message']}"
self.logger.error(msg)
return msg
msg = ''
for i, part in enumerate(data_blocks):
# print(f'process {i}: {part}')
try:
if part.startswith('[DONE]'):
break
json_data = json.loads(part)
choices = json_data['data']['choices']
for c in choices:
delta = c.get('delta')
msg += delta
except json.decoder.JSONDecodeError as err:
print(err)
self.logger.error(f'Error decoding JSON: {part}')
return msg
else: else:
print(raw_response.text) print(raw_response.text,
raw_response.headers.get('X-Request-Id'))
time.sleep(1) time.sleep(1)
continue continue
print(response) raise RuntimeError(
max_num_retries += 1 f'request id: '
f'{raw_response.headers.get("X-Request-Id")}, {raw_response.text}')
raise RuntimeError(raw_response.text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment