Unverified commit f61c9da7, authored by Rishabh Saini and committed by GitHub
Browse files

[BugFix] deepseek_v32_encoding: Replace asserts with proper exceptions (#32884)


Signed-off-by: RishabhSaini <rishabhsaini01@gmail.com>
parent 7fe25588
...@@ -154,10 +154,12 @@ def find_last_user_index(messages: list[dict[str, Any]]) -> int: ...@@ -154,10 +154,12 @@ def find_last_user_index(messages: list[dict[str, Any]]) -> int:
def render_message( def render_message(
index: int, messages: list[dict[str, Any]], thinking_mode: str index: int, messages: list[dict[str, Any]], thinking_mode: str
) -> str: ) -> str:
assert 0 <= index < len(messages) if not (0 <= index < len(messages)):
assert thinking_mode in ["chat", "thinking"], ( raise ValueError(
f"Invalid thinking_mode `{thinking_mode}`" f"Index {index} out of range for messages list of length {len(messages)}"
) )
if thinking_mode not in ["chat", "thinking"]:
raise ValueError(f"Invalid thinking_mode `{thinking_mode}`")
prompt = "" prompt = ""
msg = messages[index] msg = messages[index]
...@@ -187,7 +189,8 @@ def render_message( ...@@ -187,7 +189,8 @@ def render_message(
) )
elif role == "developer": elif role == "developer":
assert content, f"Invalid message for role `{role}`: {msg}" if not content:
raise ValueError(f"Invalid message for role `{role}`: {msg}")
content_developer = "" content_developer = ""
if tools: if tools:
content_developer += "\n\n" + render_tools(tools) content_developer += "\n\n" + render_tools(tools)
...@@ -220,17 +223,17 @@ def render_message( ...@@ -220,17 +223,17 @@ def render_message(
prev_assistant_idx -= 1 prev_assistant_idx -= 1
assistant_msg = messages[prev_assistant_idx] assistant_msg = messages[prev_assistant_idx]
assert ( if not (
index == 0 index == 0
or prev_assistant_idx >= 0 or prev_assistant_idx >= 0
and assistant_msg.get("role") == "assistant" and assistant_msg.get("role") == "assistant"
), f"Invalid messages at {index}:\n{assistant_msg}" ):
raise ValueError(f"Invalid messages at {index}:\n{assistant_msg}")
tool_call_order = index - prev_assistant_idx tool_call_order = index - prev_assistant_idx
assistant_tool_calls = assistant_msg.get("tool_calls") assistant_tool_calls = assistant_msg.get("tool_calls")
assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, ( if not (assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order):
"No tool calls but found tool output" raise ValueError("No tool calls but found tool output")
)
if tool_call_order == 1: if tool_call_order == 1:
prompt += "\n\n<function_results>" prompt += "\n\n<function_results>"
...@@ -266,7 +269,8 @@ def render_message( ...@@ -266,7 +269,8 @@ def render_message(
summary_content = content or "" summary_content = content or ""
if thinking_mode == "thinking" and index > last_user_idx: if thinking_mode == "thinking" and index > last_user_idx:
assert reasoning_content or tool_calls, ( if not (reasoning_content or tool_calls):
raise ValueError(
f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message" f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message"
) )
thinking_part = ( thinking_part = (
...@@ -362,12 +366,14 @@ def parse_tool_calls(index: int, text: str): ...@@ -362,12 +366,14 @@ def parse_tool_calls(index: int, text: str):
index, _, stop_token = _read_until_stop( index, _, stop_token = _read_until_stop(
index, text, [f"<{dsml_token}invoke", tool_calls_end_token] index, text, [f"<{dsml_token}invoke", tool_calls_end_token]
) )
assert _ == ">\n", "Tool call format error" if _ != ">\n":
raise RuntimeError("Tool call format error")
if stop_token == tool_calls_end_token: if stop_token == tool_calls_end_token:
break break
assert stop_token is not None, "Missing special token" if stop_token is None:
raise RuntimeError("Missing special token")
index, tool_name_content, stop_token = _read_until_stop( index, tool_name_content, stop_token = _read_until_stop(
index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"] index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"]
...@@ -376,7 +382,8 @@ def parse_tool_calls(index: int, text: str): ...@@ -376,7 +382,8 @@ def parse_tool_calls(index: int, text: str):
p_tool_name = re.findall( p_tool_name = re.findall(
r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL
) )
assert len(p_tool_name) == 1, "Tool name format error" if len(p_tool_name) != 1:
raise RuntimeError("Tool name format error")
tool_name = p_tool_name[0] tool_name = p_tool_name[0]
tool_args: dict[str, tuple[str, str]] = {} tool_args: dict[str, tuple[str, str]] = {}
...@@ -390,16 +397,19 @@ def parse_tool_calls(index: int, text: str): ...@@ -390,16 +397,19 @@ def parse_tool_calls(index: int, text: str):
param_content, param_content,
flags=re.DOTALL, flags=re.DOTALL,
) )
assert len(param_kv) == 1, "Parameter format error" if len(param_kv) != 1:
raise RuntimeError("Parameter format error")
param_name, string, param_value = param_kv[0] param_name, string, param_value = param_kv[0]
assert param_name not in tool_args, "Duplicate parameter name" if param_name in tool_args:
raise RuntimeError("Duplicate parameter name")
tool_args[param_name] = (param_value, string) tool_args[param_name] = (param_value, string)
index, content, stop_token = _read_until_stop( index, content, stop_token = _read_until_stop(
index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"] index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"]
) )
assert content == ">\n", "Parameter format error" if content != ">\n":
raise RuntimeError("Parameter format error")
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
tool_calls.append(tool_call) tool_calls.append(tool_call)
...@@ -422,7 +432,8 @@ def parse_message_from_completion_text(text: str, thinking_mode: str): ...@@ -422,7 +432,8 @@ def parse_message_from_completion_text(text: str, thinking_mode: str):
index, text, [thinking_end_token, tool_calls_start_token] index, text, [thinking_end_token, tool_calls_start_token]
) )
reasoning_content = content_delta reasoning_content = content_delta
assert stop_token == thinking_end_token, "Invalid thinking format" if stop_token != thinking_end_token:
raise RuntimeError("Invalid thinking format")
index, content_delta, stop_token = _read_until_stop( index, content_delta, stop_token = _read_until_stop(
index, text, [eos_token, tool_calls_start_token] index, text, [eos_token, tool_calls_start_token]
...@@ -431,17 +442,18 @@ def parse_message_from_completion_text(text: str, thinking_mode: str): ...@@ -431,17 +442,18 @@ def parse_message_from_completion_text(text: str, thinking_mode: str):
if stop_token == tool_calls_start_token: if stop_token == tool_calls_start_token:
is_tool_calling = True is_tool_calling = True
else: else:
assert stop_token == eos_token, "Invalid summary format" if stop_token != eos_token:
raise RuntimeError("Invalid summary format")
if is_tool_calling: if is_tool_calling:
index, stop_token, tool_calls = parse_tool_calls(index, text) index, stop_token, tool_calls = parse_tool_calls(index, text)
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
assert not tool_ends_text, "Unexpected content after tool calls" if tool_ends_text:
raise RuntimeError("Unexpected content after tool calls")
assert len(text) == index and stop_token in [eos_token, None], ( if not (len(text) == index and stop_token in [eos_token, None]):
"Unexpected content at end" raise RuntimeError("Unexpected content at end")
)
for sp_token in [ for sp_token in [
bos_token, bos_token,
...@@ -450,9 +462,8 @@ def parse_message_from_completion_text(text: str, thinking_mode: str): ...@@ -450,9 +462,8 @@ def parse_message_from_completion_text(text: str, thinking_mode: str):
thinking_end_token, thinking_end_token,
dsml_token, dsml_token,
]: ]:
assert sp_token not in summary_content and sp_token not in reasoning_content, ( if sp_token in summary_content or sp_token in reasoning_content:
"Unexpected special token in content" raise RuntimeError("Unexpected special token in content")
)
return { return {
"role": "assistant", "role": "assistant",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment