Unverified Commit e6ea0315 authored by mans's avatar mans Committed by GitHub
Browse files

fix request hanging when request api (#3090)



* fix request hanging when request api

* pre commit

---------
Co-authored-by: default avatarqinyidao <qinyidao@moonshot.cn>
parent 489fbc21
......@@ -447,6 +447,7 @@ class TemplateAPI(TemplateLM):
async def amodel_call(
self,
session: ClientSession,
sem: asyncio.Semaphore,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
*,
generate: bool = True,
......@@ -465,6 +466,7 @@ class TemplateAPI(TemplateLM):
**kwargs,
)
cache_method = "generate_until" if generate else "loglikelihood"
acquired = await sem.acquire()
try:
async with session.post(
self.base_url,
......@@ -474,7 +476,8 @@ class TemplateAPI(TemplateLM):
if not response.ok:
error_text = await response.text()
eval_logger.warning(
f"API request failed with error message: {error_text}. Retrying..."
f"API request failed! Status code: {response.status}, "
f"Response text: {error_text}. Retrying..."
)
# raising exception will retry the request
response.raise_for_status()
......@@ -495,11 +498,12 @@ class TemplateAPI(TemplateLM):
self.cache_hook.add_partial(cache_method, cache, res)
return answers
# If the retries also fail
except RetryError:
eval_logger.error(
"API request failed after multiple retries. Please check the API status."
)
return None
except BaseException as e:
eval_logger.error(f"Exception:{repr(e)}, {outputs}, retrying.")
raise e
finally:
if acquired:
sem.release()
def batch_loglikelihood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]]
......@@ -535,6 +539,7 @@ class TemplateAPI(TemplateLM):
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
sem = asyncio.Semaphore(self._concurrent)
async with ClientSession(
connector=conn, timeout=ClientTimeout(total=self.timeout)
) as session:
......@@ -542,12 +547,16 @@ class TemplateAPI(TemplateLM):
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
before_sleep=lambda retry_state: eval_logger.info(
f"Retry attempt {retry_state.attempt_number}"
),
)(self.amodel_call)
# Create tasks for each batch of request
tasks = [
asyncio.create_task(
retry_(
session=session,
sem=sem,
messages=message,
cache_keys=cache_key,
generate=generate,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment