Unverified commit e6ea0315, authored by mans, committed by GitHub
Browse files

fix request hanging when request api (#3090)



* fix request hanging when request api

* pre commit

---------
Co-authored-by: qinyidao <qinyidao@moonshot.cn>
parent 489fbc21
...@@ -447,6 +447,7 @@ class TemplateAPI(TemplateLM): ...@@ -447,6 +447,7 @@ class TemplateAPI(TemplateLM):
async def amodel_call( async def amodel_call(
self, self,
session: ClientSession, session: ClientSession,
sem: asyncio.Semaphore,
messages: Union[List[List[int]], List[str], List[JsonChatStr]], messages: Union[List[List[int]], List[str], List[JsonChatStr]],
*, *,
generate: bool = True, generate: bool = True,
...@@ -465,6 +466,7 @@ class TemplateAPI(TemplateLM): ...@@ -465,6 +466,7 @@ class TemplateAPI(TemplateLM):
**kwargs, **kwargs,
) )
cache_method = "generate_until" if generate else "loglikelihood" cache_method = "generate_until" if generate else "loglikelihood"
acquired = await sem.acquire()
try: try:
async with session.post( async with session.post(
self.base_url, self.base_url,
...@@ -474,7 +476,8 @@ class TemplateAPI(TemplateLM): ...@@ -474,7 +476,8 @@ class TemplateAPI(TemplateLM):
if not response.ok: if not response.ok:
error_text = await response.text() error_text = await response.text()
eval_logger.warning( eval_logger.warning(
f"API request failed with error message: {error_text}. Retrying..." f"API request failed! Status code: {response.status}, "
f"Response text: {error_text}. Retrying..."
) )
# raising exception will retry the request # raising exception will retry the request
response.raise_for_status() response.raise_for_status()
...@@ -495,11 +498,12 @@ class TemplateAPI(TemplateLM): ...@@ -495,11 +498,12 @@ class TemplateAPI(TemplateLM):
self.cache_hook.add_partial(cache_method, cache, res) self.cache_hook.add_partial(cache_method, cache, res)
return answers return answers
# If the retries also fail # If the retries also fail
except RetryError: except BaseException as e:
eval_logger.error( eval_logger.error(f"Exception:{repr(e)}, {outputs}, retrying.")
"API request failed after multiple retries. Please check the API status." raise e
) finally:
return None if acquired:
sem.release()
def batch_loglikelihood_requests( def batch_loglikelihood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]] self, chunks: Iterable[List[LogLikelihoodInputs]]
...@@ -535,6 +539,7 @@ class TemplateAPI(TemplateLM): ...@@ -535,6 +539,7 @@ class TemplateAPI(TemplateLM):
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]: ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests) ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate) conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
sem = asyncio.Semaphore(self._concurrent)
async with ClientSession( async with ClientSession(
connector=conn, timeout=ClientTimeout(total=self.timeout) connector=conn, timeout=ClientTimeout(total=self.timeout)
) as session: ) as session:
...@@ -542,12 +547,16 @@ class TemplateAPI(TemplateLM): ...@@ -542,12 +547,16 @@ class TemplateAPI(TemplateLM):
stop=stop_after_attempt(self.max_retries), stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10), wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True, reraise=True,
before_sleep=lambda retry_state: eval_logger.info(
f"Retry attempt {retry_state.attempt_number}"
),
)(self.amodel_call) )(self.amodel_call)
# Create tasks for each batch of request # Create tasks for each batch of request
tasks = [ tasks = [
asyncio.create_task( asyncio.create_task(
retry_( retry_(
session=session, session=session,
sem=sem,
messages=message, messages=message,
cache_keys=cache_key, cache_keys=cache_key,
generate=generate, generate=generate,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment