Unverified commit b912de11, authored by Lianmin Zheng, committed by GitHub

Make stop reason a dict instead of str (#1407)

parent eb02c161
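The change in one line: `meta_info["finish_reason"]` goes from a display string such as `"FINISH_LENGTH: 128"` to a JSON-style dict such as `{"type": "length", "length": 128}`, where `"type"` matches the OpenAI API's `finish_reason` values. A before/after sketch (the concrete values are hypothetical):

```python
# Before this commit: the str() of the finish-reason object.
meta_info_old = {"finish_reason": "FINISH_LENGTH: 128"}  # hypothetical values

# After: a structured dict; "type" matches OpenAI's finish_reason labels.
meta_info_new = {"finish_reason": {"type": "length", "length": 128}}
```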
@@ -56,7 +56,7 @@ class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
 
-    def __str__(self):
+    def to_json(self):
         raise NotImplementedError("Subclasses must implement this method")
 
 
@@ -65,34 +65,45 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
         super().__init__()
         self.matched = matched
 
-    def __str__(self) -> str:
-        return f"FINISH_MATCHED_TOKEN: {self.matched}"
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
 
 
-class FINISH_LENGTH(BaseFinishReason):
-    def __init__(self, length: int):
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
         super().__init__()
-        self.length = length
+        self.matched = matched
 
-    def __str__(self) -> str:
-        return f"FINISH_LENGTH: {self.length}"
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
 
 
-class FINISH_MATCHED_STR(BaseFinishReason):
-    def __init__(self, matched: str):
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
         super().__init__()
-        self.matched = matched
+        self.length = length
 
-    def __str__(self) -> str:
-        return f"FINISH_MATCHED_STR: {self.matched}"
+    def to_json(self):
+        return {
+            "type": "length",  # to match OpenAI API's return value
+            "length": self.length,
+        }
 
 
 class FINISH_ABORT(BaseFinishReason):
     def __init__(self):
         super().__init__(is_error=True)
 
-    def __str__(self) -> str:
-        return "FINISH_ABORT"
+    def to_json(self):
+        return {
+            "type": "abort",
+        }
 
 
 class Req:
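Pulled out of the diff above, the new contract is small enough to sketch end to end. This is a minimal reconstruction for illustration, not the full module, and the example value `128` is made up:

```python
class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError("Subclasses must implement this method")


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        # "length" matches the OpenAI API's finish_reason value.
        return {"type": "length", "length": self.length}


reason = FINISH_LENGTH(length=128)  # hypothetical length
assert reason.to_json() == {"type": "length", "length": 128}
```

Note that FINISH_MATCHED_TOKEN and FINISH_MATCHED_STR both map to `"type": "stop"`, so OpenAI-compatible clients see a single label while the `"matched"` field preserves which token or string actually triggered the stop.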
@@ -813,7 +813,11 @@ class ModelTpServer:
                 "prompt_tokens": len(req.origin_input_ids),
                 "completion_tokens": len(req.output_ids),
                 "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                "finish_reason": str(req.finished_reason),
+                "finish_reason": (
+                    req.finished_reason.to_json()
+                    if req.finished_reason is not None
+                    else None
+                ),
             }
             if req.return_logprob:
                 (
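One subtlety this branch handles explicitly: an unfinished request has `req.finished_reason = None`, and the old `str(req.finished_reason)` turned that into the literal string `"None"` (which is exactly why the deleted `format_finish_reason` helper below checked `startswith("None")`). Now `None` stays `None`. A standalone sketch of the expression; the function name is made up for illustration:

```python
def finish_reason_json(finished_reason):
    # Unfinished request: propagate None instead of the old "None" string.
    return finished_reason.to_json() if finished_reason is not None else None

assert finish_reason_json(None) is None
```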
@@ -95,19 +95,6 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None
 
 
-def format_finish_reason(finish_reason) -> Optional[str]:
-    if finish_reason.startswith("None"):
-        return None
-    elif finish_reason.startswith("FINISH_MATCHED"):
-        return "stop"
-    elif finish_reason.startswith("FINISH_LENGTH"):
-        return "length"
-    elif finish_reason.startswith("FINISH_ABORT"):
-        return "abort"
-    else:
-        return "unknown"
-
-
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
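The helper removed above classified reasons by string prefix, which was fragile: any new reason class would silently fall into the `"unknown"` branch. With the dict format the OpenAI-compatible label travels inside the payload, so the call sites below read it directly. A hedged old-vs-new comparison (the meta_info value is made up):

```python
meta_info = {"finish_reason": {"type": "length", "length": 128}}  # hypothetical

# Old: parse the label back out of a display string.
#   format_finish_reason("FINISH_LENGTH: 128")  ->  "length"

# New: the label is already a field; no string parsing.
reason = meta_info["finish_reason"]
finish_reason = reason["type"] if reason else ""  # "" mirrors the call sites below
```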
@@ -618,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
+                "finish_reason": (
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
                 ),
             }
         else:
@@ -627,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
+                finish_reason=(
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
                 ),
             )
@@ -762,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     index=index,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=format_finish_reason(
-                        content["meta_info"]["finish_reason"]
+                    finish_reason=(
+                        content["meta_info"]["finish_reason"]["type"]
+                        if content["meta_info"]["finish_reason"]
+                        else ""
                     ),
                 )
                 chunk = CompletionStreamResponse(
@@ -999,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
+                "finish_reason": (
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
                 ),
             }
         else:
@@ -1008,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=format_finish_reason(
-                    ret_item["meta_info"]["finish_reason"]
+                finish_reason=(
+                    ret_item["meta_info"]["finish_reason"]["type"]
+                    if ret_item["meta_info"]["finish_reason"]
+                    else ""
                 ),
             )
@@ -1134,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=index,
                         delta=DeltaMessage(role="assistant"),
-                        finish_reason=format_finish_reason(
-                            content["meta_info"]["finish_reason"]
+                        finish_reason=(
+                            content["meta_info"]["finish_reason"]["type"]
+                            if content["meta_info"]["finish_reason"]
+                            else ""
                         ),
                         logprobs=choice_logprobs,
                     )
@@ -1152,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(content=delta),
-                    finish_reason=format_finish_reason(
-                        content["meta_info"]["finish_reason"]
+                    finish_reason=(
+                        content["meta_info"]["finish_reason"]["type"]
+                        if content["meta_info"]["finish_reason"]
+                        else ""
                    ),
                     logprobs=choice_logprobs,
                 )
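Because every `to_json()` payload is a plain dict of JSON types, it serializes cleanly through the server's responses, and the extra detail (`matched`, `length`) survives alongside the OpenAI-compatible `type`. A short end-to-end sketch with hypothetical values:

```python
import json

finish_reason = {"type": "stop", "matched": 2}  # e.g. FINISH_MATCHED_TOKEN(2).to_json()

# OpenAI-compatible endpoints report just the label...
label = finish_reason["type"] if finish_reason else ""  # -> "stop"

# ...while the full structured reason can ride along in meta_info.
print(json.dumps({"meta_info": {"finish_reason": finish_reason}}))
```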