Commit 84ac1d27 authored by zhangning3's avatar zhangning3
Browse files

add step3p5 mtp3

parent b27f1671
......@@ -57,6 +57,11 @@ def parse_args():
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true")
parser.add_argument("--enable-chunked-prefill", action="store_true")
parser.add_argument(
"--enable-multi-layers-mtp",
action="store_true",
help="Enable multi-layer MTP (only effective when --method=mtp).",
)
parser.add_argument("--max-model-len", type=int, default=16384)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--top-p", type=float, default=1.0)
......@@ -66,12 +71,14 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--tokenizer-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="")
parser.add_argument("--trust-remote-code", action="store_true")
return parser.parse_args()
......@@ -85,7 +92,11 @@ def main(args):
"please specify model_dir to give a mm based model"
)
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer_dir = args.tokenizer_dir
if tokenizer_dir is None:
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
if args.custom_mm_prompts:
prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
......@@ -141,6 +152,8 @@ def main(args):
"method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
if args.enable_multi_layers_mtp:
speculative_config["enable_multi_layers_mtp"] = True
else:
raise ValueError(f"unknown method: {args.method}")
......
#!/usr/bin/env bash
set -u
DEFAULT_BATCH_SIZES=(1 8 16 32 64 128)
MODEL_PATH="/module/step3.5-fp8/"
SERVED_MODEL_NAME="/module/step3.5-fp8/"
DATASET_NAME="random"
DEFAULT_OUTPUT_LEN_DECODE=4096
DEFAULT_OUTPUT_LEN_PREFILL=1
DEFAULT_ROLE="decode"
READY_CHECK_TIMEOUT=3
RESULT_DIR="benchmark_result"
print_usage() {
cat <<'USAGE'
Usage:
./scripts/step3p5_benchmark_test.sh
./scripts/step3p5_benchmark_test.sh 1,8,16,32
./scripts/step3p5_benchmark_test.sh --role prefill
./scripts/step3p5_benchmark_test.sh --role decode
./scripts/step3p5_benchmark_test.sh --role both
./scripts/step3p5_benchmark_test.sh --role prefill --output-len 1
Description:
- No argument: use default batch sizes, role=decode, output-len=4096
- Optional positional argument: batch size list (comma or space separated)
- Optional flag: --role <prefill|decode|both>
- Optional flag: --output-len <N> (must be positive integer)
- role=both 时串行执行 prefill 再 decode
- Result files are saved under:
<result_dir>/prefill (when role=prefill)
<result_dir>/decode (when role=decode)
USAGE
}
parse_batch_sizes() {
local raw_input="${1:-}"
if [[ -z "$raw_input" ]]; then
BATCH_SIZES=("${DEFAULT_BATCH_SIZES[@]}")
return
fi
raw_input="${raw_input//,/ }"
read -r -a BATCH_SIZES <<< "$raw_input"
if (( ${#BATCH_SIZES[@]} == 0 )); then
echo "[ERROR] batch size 列表为空。"
print_usage
exit 1
fi
local batch_size
for batch_size in "${BATCH_SIZES[@]}"; do
if ! [[ "$batch_size" =~ ^[1-9][0-9]*$ ]]; then
echo "[ERROR] 非法 batch size: $batch_size(必须是正整数)"
exit 1
fi
done
}
parse_role() {
local role_input="${1:-$DEFAULT_ROLE}"
if [[ "$role_input" != "prefill" && "$role_input" != "decode" && "$role_input" != "both" ]]; then
echo "[ERROR] 非法 role: $role_input(必须是 prefill、decode 或 both)"
print_usage
exit 1
fi
ROLE="$role_input"
}
parse_output_len() {
local output_len_input="$1"
if ! [[ "$output_len_input" =~ ^[1-9][0-9]*$ ]]; then
echo "[ERROR] 非法 output-len: $output_len_input(必须是正整数)"
print_usage
exit 1
fi
RANDOM_OUTPUT_LEN="$output_len_input"
}
calc_num_prompts() {
local batch_size="$1"
local num_prompts=$((batch_size + batch_size / 2))
if (( num_prompts < 16 )); then
num_prompts=16
fi
if (( num_prompts > 384 )); then
num_prompts=384
fi
if (( num_prompts < batch_size )); then
num_prompts=$batch_size
fi
echo "$num_prompts"
}
main() {
local batch_arg=""
local role_arg="$DEFAULT_ROLE"
local output_len_arg=""
while (( $# > 0 )); do
case "$1" in
-h|--help)
print_usage
exit 0
;;
--role|-r)
if [[ -z "${2:-}" ]]; then
echo "[ERROR] --role 缺少参数。"
print_usage
exit 1
fi
role_arg="$2"
shift 2
;;
--output-len|-o)
if [[ -z "${2:-}" ]]; then
echo "[ERROR] --output-len 缺少参数。"
print_usage
exit 1
fi
output_len_arg="$2"
shift 2
;;
--*)
echo "[ERROR] 未知参数: $1"
print_usage
exit 1
;;
*)
if [[ -n "$batch_arg" ]]; then
echo "[ERROR] 仅支持一个 batch size 列表参数。"
print_usage
exit 1
fi
batch_arg="$1"
shift
;;
esac
done
parse_batch_sizes "$batch_arg"
parse_role "$role_arg"
if [[ "$ROLE" == "both" ]]; then
local -a prefill_cmd=("$0")
local -a decode_cmd=("$0")
if [[ -n "$batch_arg" ]]; then
prefill_cmd+=("$batch_arg")
decode_cmd+=("$batch_arg")
fi
prefill_cmd+=("--role" "prefill")
decode_cmd+=("--role" "decode")
if [[ -n "$output_len_arg" ]]; then
decode_cmd+=("--output-len" "$output_len_arg")
fi
echo "[INFO] role=both: 将串行执行 prefill 和 decode"
echo "[INFO] step1: ${prefill_cmd[*]}"
"${prefill_cmd[@]}"
echo "[INFO] step2: ${decode_cmd[*]}"
"${decode_cmd[@]}"
echo "[INFO] role=both 执行完成。"
return 0
fi
if [[ "$ROLE" == "prefill" ]]; then
if [[ -n "$output_len_arg" && "$output_len_arg" != "$DEFAULT_OUTPUT_LEN_PREFILL" ]]; then
echo "[WARN] role=prefill 时 output-len 必须为 1,已自动覆盖为 1。"
fi
output_len_arg="$DEFAULT_OUTPUT_LEN_PREFILL"
elif [[ -z "$output_len_arg" ]]; then
output_len_arg="$DEFAULT_OUTPUT_LEN_DECODE"
fi
parse_output_len "$output_len_arg"
RESULT_SUBDIR="$RESULT_DIR/$ROLE"
mkdir -p "$RESULT_SUBDIR"
echo "[INFO] 将执行 ${#BATCH_SIZES[@]} 组 benchmark"
echo "[INFO] role: $ROLE"
echo "[INFO] random-output-len: $RANDOM_OUTPUT_LEN"
echo "[INFO] batch size 列表: ${BATCH_SIZES[*]}"
echo "[INFO] result_dir: $RESULT_SUBDIR"
local batch_size
local num_prompts
local failed_count=0
for batch_size in "${BATCH_SIZES[@]}"; do
num_prompts="$(calc_num_prompts "$batch_size")"
echo ""
echo "[INFO] 开始测试: role=$ROLE, batch_size=$batch_size, max_concurrency=$batch_size, num_prompts=$num_prompts, random_output_len=$RANDOM_OUTPUT_LEN"
if ! vllm bench serve \
--backend vllm \
--model "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--dataset-name "$DATASET_NAME" \
--random-input-len 65536 \
--random-output-len "$RANDOM_OUTPUT_LEN" \
--num-prompts "$num_prompts" \
--temperature 0 \
--max-concurrency "$batch_size" \
--ready-check-timeout "$READY_CHECK_TIMEOUT" \
--result-dir "$RESULT_SUBDIR" \
--port 8018 \
--save-result; then
echo "[WARN] batch_size=$batch_size 执行失败,继续下一个。"
failed_count=$((failed_count + 1))
else
echo "[INFO] batch_size=$batch_size 执行完成。"
fi
done
echo ""
if (( failed_count > 0 )); then
echo "[WARN] 全部执行结束,但有 $failed_count 组失败。"
exit 1
fi
echo "[INFO] 全部 benchmark 执行完成。"
}
main "$@"
......@@ -386,6 +386,57 @@ TX
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
def test_extract_tool_calls_missing_function_name(step3p5_tool_parser, sample_tools):
"""Tool call without function name should fallback to content."""
model_output = (
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls(
model_output, request=request
)
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def test_extract_tool_calls_empty_function_name(step3p5_tool_parser, sample_tools):
"""Tool call with empty function name should fallback to content."""
model_output = (
"<tool_call><function=><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls(
model_output, request=request
)
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def test_extract_tool_calls_empty_function_name_single_quotes(
step3p5_tool_parser, sample_tools
):
"""Tool call with empty function name (single quotes) should fallback."""
model_output = (
"<tool_call><function=''><parameter=pattern>*.py</parameter></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls(
model_output, request=request
)
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def test_extract_tool_calls_type_conversion(step3p5_tool_parser):
"""Test parameter type conversion based on tool schema"""
tools = [
......@@ -623,7 +674,7 @@ def test_extract_tool_calls_streaming(
expected_tool_calls,
expected_content,
):
"""Test incremental streaming behavior including typed parameters"""
"""Test streaming returns complete tool calls when parsed."""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ""
......@@ -647,11 +698,10 @@ def test_extract_tool_calls_streaming(
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
# First chunk should have id, name, and type
if tool_call.id:
tool_states[idx]["id"] = tool_call.id
......@@ -666,8 +716,15 @@ def test_extract_tool_calls_streaming(
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx]["arguments"] += tool_call.function.arguments
# Arguments should be complete JSON when emitted.
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Verify final content
assert other_content == (expected_content or "") # Handle None case
......@@ -682,7 +739,7 @@ def test_extract_tool_calls_streaming(
assert state["type"] == "function"
assert state["name"] == expected_tool.function.name
# Parse accumulated arguments
# Parse arguments
arguments_str = state["arguments"]
assert arguments_str is not None
actual_args = json.loads(arguments_str)
......@@ -770,7 +827,7 @@ fahrenheit
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
......@@ -786,7 +843,14 @@ fahrenheit
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Verify content was streamed
assert "Let me check the weather for you:" in other_content
......@@ -806,62 +870,69 @@ fahrenheit
assert args["unit"] == "fahrenheit"
def test_extract_tool_calls_streaming_incremental(
def test_extract_tool_calls_streaming_missing_function_name(
step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
"""Test that streaming is truly incremental"""
model_output = """I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
"""Streaming: missing function name should be treated as content."""
model_output = (
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
chunks = []
for delta_message in stream_delta_message_generator(
step3p5_tool_parser, step3p5_tokenizer, model_output, request
other_content = ""
tool_calls = []
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser,
step3p5_tokenizer,
[
"<tool_call><parameter=pattern>",
"*.py</parameter>",
"</function></tool_call>",
],
request,
):
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
tool_calls.extend(delta_message.tool_calls)
assert other_content == model_output
assert tool_calls == []
def test_extract_tool_calls_streaming_empty_function_name(
step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
"""Streaming: empty function name should be treated as content."""
model_output = (
"<tool_call><function=''><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ""
tool_calls = []
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser,
step3p5_tokenizer,
[
"<tool_call><function=",
"''><parameter=pattern>*.py</parameter>",
"</function></tool_call>",
],
request,
):
chunks.append(delta_message)
# Should have multiple chunks
assert len(chunks) > 3
# First chunk(s) should be content
assert chunks[0].content is not None
assert chunks[0].tool_calls is None or chunks[0].tool_calls == []
# Should have a chunk with tool header (id, name, type)
header_found = False
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert chunk.tool_calls[0].function.name == "get_current_weather"
assert chunk.tool_calls[0].type == "function"
# Empty initially
assert chunk.tool_calls[0].function.arguments == ""
break
assert header_found
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].function.arguments:
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
assert len(arg_chunks) > 1
# Concatenated arguments should form valid JSON
full_args = "".join(arg_chunks)
parsed_args = json.loads(full_args)
assert parsed_args["city"] == "Dallas"
assert parsed_args["state"] == "TX"
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
tool_calls.extend(delta_message.tool_calls)
assert other_content == model_output
assert tool_calls == []
def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser):
......@@ -951,7 +1022,7 @@ rectangle
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
......@@ -967,7 +1038,14 @@ rectangle
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls"
......@@ -1164,7 +1242,7 @@ rectangle
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
if tool_call.id:
......@@ -1175,7 +1253,14 @@ rectangle
if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls"
......@@ -1266,7 +1351,7 @@ rectangle
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
......@@ -1282,7 +1367,14 @@ rectangle
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls"
......@@ -1344,20 +1436,26 @@ rectangle""",
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser, step3p5_tokenizer, delta_text_chunks, request
):
print(delta_message)
if delta_message.tool_calls:
for tool_call in delta_message.tool_calls:
idx = tool_call.index
if idx not in tool_states:
tool_states[idx] = {
"name": None,
"arguments": "",
"arguments": None,
}
if tool_call.function:
if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
assert len(tool_states) == 2
assert all(state["name"] for state in tool_states.values())
......@@ -1368,7 +1466,7 @@ rectangle""",
def test_extract_tool_calls_non_streaming_multiple_tool_calls_no_content_between(
step3p5_tool_parser, sample_tools
):
"""Test non-streaming extraction with tool calls and no content between them.
"""Test non-streaming extraction with multiple tool calls.
Scenario: Model outputs "hello" + tool call + tool call.
Expected: "hello" as content, first tool call parsed (index=0),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from types import SimpleNamespace
import pytest
import torch
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer
HIDDEN_SIZE = 3
def _make_multi_layer_eagle_metadata(
*,
initial_cache: list[dict],
max_shift: int,
device: torch.device,
) -> MultiLayerEagleMetadata:
for row in initial_cache:
assert "len" in row
row_len = int(row["len"])
assert 0 <= row_len <= max_shift
# Test cases pad cache rows to `layer_num` (== max_shift) and specify the
# number of valid entries via `len`.
assert (
len(row["token_ids"])
== len(row["positions"])
== len(row["slot_mapping"])
== max_shift
)
assert all(v == 0 for v in row["token_ids"][row_len:])
assert all(v == 0 for v in row["positions"][row_len:])
assert all(v == 0 for v in row["slot_mapping"][row_len:])
cached_len = torch.tensor(
[min(int(row["len"]), max_shift) for row in initial_cache],
dtype=torch.int64,
device=device,
)
cached_token_ids = torch.tensor(
[row["token_ids"] for row in initial_cache],
dtype=torch.int32,
device=device,
)
cached_positions = torch.tensor(
[row["positions"] for row in initial_cache],
dtype=torch.int64,
device=device,
)
cached_slot_mappings = torch.tensor(
[row["slot_mapping"] for row in initial_cache],
dtype=torch.int64,
device=device,
)
cached_hidden_states = torch.zeros(
(len(initial_cache), max_shift, HIDDEN_SIZE),
dtype=torch.float32,
device=device,
)
return MultiLayerEagleMetadata(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
@pytest.fixture
def proposer_stub():
if not torch.cuda.is_available():
pytest.skip("MultiLayerEagleProposer.adjust_input is CUDA/Triton-only.")
proposer = MultiLayerEagleProposer.__new__(MultiLayerEagleProposer)
proposer.layer_num = 3
return proposer
LAYER3_CASES = [
{
"name": "shift_0_at_sequence_end",
"batch_size": 1,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
}
],
"target_token_ids": [10, 11, 12, 13],
"target_positions": [0, 1, 2, 3],
"token_indices_to_sample": [3],
"common_attn_metadata": {
"query_start_loc": [0, 4],
"seq_lens": [4],
"seq_lens_cpu": [4],
"num_computed_tokens_cpu": [0],
"slot_mapping": [100, 101, 102, 103],
"max_seq_len": 4,
},
"expected": {
"prev_token_ids": [10, 11, 12, 13],
"prev_positions": [0, 1, 2, 3],
"token_indices_to_sample": [3],
"seq_lens": [4],
"slot_mapping": [100, 101, 102, 103],
"cached": [
{
"len": 3,
"token_ids": [11, 12, 13],
"positions": [1, 2, 3],
"slot_mapping": [101, 102, 103],
}
],
},
},
{
"name": "batch2_short_seq_no_shift",
"batch_size": 2,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
],
"target_token_ids": [10, 11, 20],
"target_positions": [0, 1, 0],
"token_indices_to_sample": [1, 2],
"common_attn_metadata": {
"query_start_loc": [0, 2, 3],
"seq_lens": [2, 1],
"seq_lens_cpu": [2, 1],
"num_computed_tokens_cpu": [0, 0],
"slot_mapping": [100, 101, 200],
"max_seq_len": 2,
},
"expected": {
"prev_token_ids": [10, 11, 20],
"prev_positions": [0, 1, 0],
"token_indices_to_sample": [1, 2],
"seq_lens": [2, 1],
"slot_mapping": [100, 101, 200],
"cached": [
{
"len": 2,
"token_ids": [10, 11, 0],
"positions": [0, 1, 0],
"slot_mapping": [100, 101, 0],
},
{
"len": 1,
"token_ids": [20, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [200, 0, 0],
},
],
},
},
{
"name": "batch2_short_seq_shift_on_first",
"batch_size": 2,
"initial_cache": [
{
"len": 1,
"token_ids": [99, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [999, 0, 0],
},
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
],
"target_token_ids": [10, 11, 20],
"target_positions": [1, 2, 0],
"token_indices_to_sample": [0, 2],
"common_attn_metadata": {
"query_start_loc": [0, 2, 3],
"seq_lens": [2, 1],
"seq_lens_cpu": [2, 1],
"num_computed_tokens_cpu": [1, 0],
"slot_mapping": [100, 101, 200],
"max_seq_len": 2,
},
"expected": {
"prev_token_ids": [99, 10, 20],
"prev_positions": [0, 1, 0],
"token_indices_to_sample": [1, 2],
"seq_lens": [1, 1],
"slot_mapping": [999, 100, 200],
"cached": [
{
"len": 2,
"token_ids": [99, 10, 0],
"positions": [0, 1, 0],
"slot_mapping": [999, 100, 0],
},
{
"len": 1,
"token_ids": [20, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [200, 0, 0],
},
],
},
},
{
"name": "short_seq_len_2_shift_0_cache_len_1",
"batch_size": 1,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
}
],
"target_token_ids": [7, 8],
"target_positions": [0, 1],
"token_indices_to_sample": [0],
"common_attn_metadata": {
"query_start_loc": [0, 2],
"seq_lens": [2],
"seq_lens_cpu": [2],
"num_computed_tokens_cpu": [0],
"slot_mapping": [1000, 1001],
"max_seq_len": 2,
},
"expected": {
"prev_token_ids": [7, 8],
"prev_positions": [0, 1],
"token_indices_to_sample": [0],
"seq_lens": [2],
"slot_mapping": [1000, 1001],
"cached": [
{
"len": 1,
"token_ids": [7, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [1000, 0, 0],
}
],
},
},
{
"name": "short_seq_len_2_shift_1_cache_len_2",
"batch_size": 1,
"initial_cache": [
{
"len": 1,
"token_ids": [6, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [999, 0, 0],
}
],
"target_token_ids": [7, 8],
"target_positions": [1, 2],
"token_indices_to_sample": [0],
"common_attn_metadata": {
"query_start_loc": [0, 2],
"seq_lens": [2],
"seq_lens_cpu": [2],
"num_computed_tokens_cpu": [1],
"slot_mapping": [1000, 1001],
"max_seq_len": 2,
},
"expected": {
"prev_token_ids": [6, 7],
"prev_positions": [0, 1],
"token_indices_to_sample": [1],
"seq_lens": [1],
"slot_mapping": [999, 1000],
"cached": [
{
"len": 2,
"token_ids": [6, 7, 0],
"positions": [0, 1, 0],
"slot_mapping": [999, 1000, 0],
}
],
},
},
{
"name": "shift_bounded_by_start_pos_zero",
"batch_size": 1,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
}
],
"target_token_ids": [10, 11, 12, 13],
"target_positions": [0, 2, 3, 4],
"token_indices_to_sample": [1],
"common_attn_metadata": {
"query_start_loc": [0, 4],
"seq_lens": [4],
"seq_lens_cpu": [4],
"num_computed_tokens_cpu": [0],
"slot_mapping": [100, 101, 102, 103],
"max_seq_len": 4,
},
"expected": {
"prev_token_ids": [10, 11, 12, 13],
"prev_positions": [0, 2, 3, 4],
"token_indices_to_sample": [1],
"seq_lens": [4],
"slot_mapping": [100, 101, 102, 103],
"cached": [
{
"len": 2,
"token_ids": [10, 11, 0],
"positions": [0, 2, 0],
"slot_mapping": [100, 101, 0],
}
],
},
},
{
"name": "shift_bounded_by_start_pos",
"batch_size": 1,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
}
],
"target_token_ids": [10, 11, 12, 13, 14],
"target_positions": [0, 1, 2, 3, 4],
"token_indices_to_sample": [1],
"common_attn_metadata": {
"query_start_loc": [0, 5],
"seq_lens": [5],
"seq_lens_cpu": [5],
"num_computed_tokens_cpu": [1],
"slot_mapping": [100, 101, 102, 103, 104],
"max_seq_len": 5,
},
"expected": {
"prev_token_ids": [10, 11, 12, 13, 14],
"prev_positions": [0, 1, 2, 3, 4],
"token_indices_to_sample": [1],
"seq_lens": [5],
"slot_mapping": [100, 101, 102, 103, 104],
"cached": [
{
"len": 2,
"token_ids": [10, 11, 0],
"positions": [0, 1, 0],
"slot_mapping": [100, 101, 0],
}
],
},
},
{
"name": "shift_2_bounded_by_remaining",
"batch_size": 1,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
}
],
"target_token_ids": [10, 11, 12, 13, 14],
"target_positions": [0, 1, 2, 3, 4],
"token_indices_to_sample": [2],
"common_attn_metadata": {
"query_start_loc": [0, 5],
"seq_lens": [5],
"seq_lens_cpu": [5],
"num_computed_tokens_cpu": [2],
"slot_mapping": [100, 101, 102, 103, 104],
"max_seq_len": 5,
},
"expected": {
"prev_token_ids": [10, 11, 12, 13, 14],
"prev_positions": [0, 1, 2, 3, 4],
"token_indices_to_sample": [2],
"seq_lens": [5],
"slot_mapping": [100, 101, 102, 103, 104],
"cached": [
{
"len": 3,
"token_ids": [10, 11, 12],
"positions": [0, 1, 2],
"slot_mapping": [100, 101, 102],
}
],
},
},
{
"name": "shift_3_full_cache_window",
"batch_size": 1,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
}
],
"target_token_ids": [20, 21, 22, 23, 24],
"target_positions": [0, 3, 4, 5, 6],
"token_indices_to_sample": [1],
"common_attn_metadata": {
"query_start_loc": [0, 5],
"seq_lens": [5],
"seq_lens_cpu": [5],
"num_computed_tokens_cpu": [3],
"slot_mapping": [100, 101, 102, 103, 104],
"max_seq_len": 5,
},
"expected": {
"prev_token_ids": [20, 21, 22, 23, 24],
"prev_positions": [0, 3, 4, 5, 6],
"token_indices_to_sample": [1],
"seq_lens": [5],
"slot_mapping": [100, 101, 102, 103, 104],
"cached": [
{
"len": 2,
"token_ids": [20, 21, 0],
"positions": [0, 3, 0],
"slot_mapping": [100, 101, 0],
}
],
},
},
{
"name": "batch2_shift_1_and_1",
"batch_size": 2,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
],
"target_token_ids": [10, 11, 12, 13, 20, 21, 22],
"target_positions": [0, 1, 2, 3, 0, 1, 2],
"token_indices_to_sample": [1, 5],
"common_attn_metadata": {
"query_start_loc": [0, 4, 7],
"seq_lens": [4, 3],
"seq_lens_cpu": [4, 3],
"num_computed_tokens_cpu": [1, 1],
"slot_mapping": [100, 101, 102, 103, 200, 201, 202],
"max_seq_len": 4,
},
"expected": {
"prev_token_ids": [10, 11, 12, 13, 20, 21, 22],
"prev_positions": [0, 1, 2, 3, 0, 1, 2],
"token_indices_to_sample": [1, 5],
"seq_lens": [4, 3],
"slot_mapping": [100, 101, 102, 103, 200, 201, 202],
"cached": [
{
"len": 2,
"token_ids": [10, 11, 0],
"positions": [0, 1, 0],
"slot_mapping": [100, 101, 0],
},
{
"len": 2,
"token_ids": [20, 21, 0],
"positions": [0, 1, 0],
"slot_mapping": [200, 201, 0],
},
],
},
},
{
"name": "batch4_mixed_shifts",
"batch_size": 4,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
{
"len": 1,
"token_ids": [19, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [119, 0, 0],
},
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
],
"target_token_ids": [10, 11, 20, 21, 22, 30, 31, 32, 33, 40, 41, 42],
"target_positions": [0, 1, 1, 2, 3, 0, 2, 3, 4, 0, 1, 2],
"token_indices_to_sample": [1, 2, 6, 10],
"common_attn_metadata": {
"query_start_loc": [0, 2, 5, 9, 12],
"seq_lens": [2, 3, 4, 3],
"seq_lens_cpu": [2, 3, 4, 3],
"num_computed_tokens_cpu": [0, 1, 2, 1],
"slot_mapping": [
100,
101,
102,
103,
104,
105,
106,
107,
108,
109,
110,
111,
],
"max_seq_len": 4,
},
"expected": {
"prev_token_ids": [10, 11, 19, 20, 21, 30, 31, 32, 33, 40, 41, 42],
"prev_positions": [0, 1, 0, 1, 2, 0, 2, 3, 4, 0, 1, 2],
"token_indices_to_sample": [1, 3, 6, 10],
"seq_lens": [2, 2, 4, 3],
"slot_mapping": [
100,
101,
119,
102,
103,
105,
106,
107,
108,
109,
110,
111,
],
"cached": [
{
"len": 2,
"token_ids": [10, 11, 0],
"positions": [0, 1, 0],
"slot_mapping": [100, 101, 0],
},
{
"len": 2,
"token_ids": [19, 20, 0],
"positions": [0, 1, 0],
"slot_mapping": [119, 102, 0],
},
{
"len": 2,
"token_ids": [30, 31, 0],
"positions": [0, 2, 0],
"slot_mapping": [105, 106, 0],
},
{
"len": 2,
"token_ids": [40, 41, 0],
"positions": [0, 1, 0],
"slot_mapping": [109, 110, 0],
},
],
},
},
{
"name": "batch2_shift_0_and_2",
"batch_size": 2,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
],
"target_token_ids": [30, 31, 32, 40, 41, 42, 43],
"target_positions": [0, 1, 2, 0, 3, 4, 5],
"token_indices_to_sample": [2, 4],
"common_attn_metadata": {
"query_start_loc": [0, 3, 7],
"seq_lens": [3, 4],
"seq_lens_cpu": [3, 4],
"num_computed_tokens_cpu": [0, 2],
"slot_mapping": [100, 101, 102, 200, 201, 202, 203],
"max_seq_len": 4,
},
"expected": {
"prev_token_ids": [30, 31, 32, 40, 41, 42, 43],
"prev_positions": [0, 1, 2, 0, 3, 4, 5],
"token_indices_to_sample": [2, 4],
"seq_lens": [3, 4],
"slot_mapping": [100, 101, 102, 200, 201, 202, 203],
"cached": [
{
"len": 3,
"token_ids": [30, 31, 32],
"positions": [0, 1, 2],
"slot_mapping": [100, 101, 102],
},
{
"len": 2,
"token_ids": [40, 41, 0],
"positions": [0, 3, 0],
"slot_mapping": [200, 201, 0],
},
],
},
},
{
"name": "continue_req_shift_1_cache_tail_3",
"batch_size": 1,
"initial_cache": [
{
"len": 3,
"token_ids": [70, 71, 72],
"positions": [7, 8, 9],
"slot_mapping": [170, 171, 172],
}
],
"target_token_ids": [100, 101, 102, 103, 104],
"target_positions": [10, 11, 12, 13, 14],
"token_indices_to_sample": [3],
"common_attn_metadata": {
"query_start_loc": [0, 5],
"seq_lens": [5],
"seq_lens_cpu": [5],
"num_computed_tokens_cpu": [0],
"slot_mapping": [200, 201, 202, 203, 204],
"max_seq_len": 5,
},
"expected": {
"prev_token_ids": [72, 100, 101, 102, 103],
"prev_positions": [9, 10, 11, 12, 13],
"token_indices_to_sample": [4],
"seq_lens": [4],
"slot_mapping": [172, 200, 201, 202, 203],
"cached": [
{
"len": 3,
"token_ids": [101, 102, 103],
"positions": [11, 12, 13],
"slot_mapping": [201, 202, 203],
}
],
},
},
{
"name": "continue_req_shift_3_cache_tail_3",
"batch_size": 1,
"initial_cache": [
{
"len": 3,
"token_ids": [270, 271, 272],
"positions": [27, 28, 29],
"slot_mapping": [370, 371, 372],
}
],
"target_token_ids": [300, 301, 302, 303, 304, 305, 306],
"target_positions": [30, 31, 32, 33, 34, 35, 36],
"token_indices_to_sample": [3],
"common_attn_metadata": {
"query_start_loc": [0, 7],
"seq_lens": [7],
"seq_lens_cpu": [7],
"num_computed_tokens_cpu": [0],
"slot_mapping": [400, 401, 402, 403, 404, 405, 406],
"max_seq_len": 7,
},
"expected": {
"prev_token_ids": [270, 271, 272, 300, 301, 302, 303],
"prev_positions": [27, 28, 29, 30, 31, 32, 33],
"token_indices_to_sample": [6],
"seq_lens": [4],
"slot_mapping": [370, 371, 372, 400, 401, 402, 403],
"cached": [
{
"len": 3,
"token_ids": [301, 302, 303],
"positions": [31, 32, 33],
"slot_mapping": [401, 402, 403],
}
],
},
},
{
"name": "batch3_mixed_shifts_0_1_2_all_full_cache",
"batch_size": 3,
"initial_cache": [
{
"len": 0,
"token_ids": [0, 0, 0],
"positions": [0, 0, 0],
"slot_mapping": [0, 0, 0],
},
{
"len": 3,
"token_ids": [70, 71, 72],
"positions": [7, 8, 9],
"slot_mapping": [170, 171, 172],
},
{
"len": 3,
"token_ids": [270, 271, 272],
"positions": [17, 18, 19],
"slot_mapping": [370, 371, 372],
},
],
"target_token_ids": [
10,
11,
12,
13,
100,
101,
102,
103,
104,
200,
201,
202,
203,
204,
205,
],
"target_positions": [
0,
1,
2,
3,
10,
11,
12,
13,
14,
20,
21,
22,
23,
24,
25,
],
"token_indices_to_sample": [3, 7, 12],
"common_attn_metadata": {
"query_start_loc": [0, 4, 9, 15],
"seq_lens": [4, 5, 6],
"seq_lens_cpu": [4, 5, 6],
"num_computed_tokens_cpu": [0, 0, 0],
"slot_mapping": [
100,
101,
102,
103,
200,
201,
202,
203,
204,
300,
301,
302,
303,
304,
305,
],
"max_seq_len": 6,
},
"expected": {
"prev_token_ids": [
10,
11,
12,
13,
72,
100,
101,
102,
103,
271,
272,
200,
201,
202,
203,
],
"prev_positions": [
0,
1,
2,
3,
9,
10,
11,
12,
13,
18,
19,
20,
21,
22,
23,
],
"token_indices_to_sample": [3, 8, 14],
"seq_lens": [4, 4, 4],
"slot_mapping": [
100,
101,
102,
103,
172,
200,
201,
202,
203,
371,
372,
300,
301,
302,
303,
],
"cached": [
{
"len": 3,
"token_ids": [11, 12, 13],
"positions": [1, 2, 3],
"slot_mapping": [101, 102, 103],
},
{
"len": 3,
"token_ids": [101, 102, 103],
"positions": [11, 12, 13],
"slot_mapping": [201, 202, 203],
},
{
"len": 3,
"token_ids": [201, 202, 203],
"positions": [21, 22, 23],
"slot_mapping": [301, 302, 303],
},
],
},
},
]
def _run_adjust_input_case(proposer_stub, case, layer_num):
proposer = proposer_stub
proposer.layer_num = layer_num
max_shift = proposer.layer_num
device = torch.device("cuda")
initial_cache = case["initial_cache"]
batch_size = case["batch_size"]
assert len(initial_cache) == batch_size
meta = case["common_attn_metadata"]
query_start_loc_cpu = torch.tensor(
meta["query_start_loc"], dtype=torch.int32, device="cpu"
)
common_attn_metadata = SimpleNamespace(
query_start_loc=query_start_loc_cpu.to(device=device),
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=torch.tensor(meta["seq_lens"], dtype=torch.int32, device=device),
seq_lens_cpu=torch.tensor(
meta["seq_lens_cpu"], dtype=torch.int32, device="cpu"
),
num_computed_tokens_cpu=torch.tensor(
meta["num_computed_tokens_cpu"], dtype=torch.int32, device="cpu"
),
slot_mapping=torch.tensor(
meta["slot_mapping"], dtype=torch.int64, device=device
),
max_seq_len=meta["max_seq_len"],
)
target_token_ids = torch.tensor(
case["target_token_ids"], dtype=torch.int32, device=device
)
target_positions = torch.tensor(
case["target_positions"], dtype=torch.int64, device=device
)
target_hidden_states = torch.arange(
0, target_token_ids.numel() * HIDDEN_SIZE, dtype=torch.float32, device=device
).reshape(-1, HIDDEN_SIZE)
token_indices_to_sample = torch.tensor(
case["token_indices_to_sample"], dtype=torch.int32, device=device
)
multi_layer_eagle_metadata = _make_multi_layer_eagle_metadata(
initial_cache=initial_cache,
max_shift=max_shift,
device=device,
)
prev_token_ids, prev_positions, _, _ = proposer.adjust_input(
batch_size=batch_size,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
expected = case["expected"]
assert len(expected["cached"]) == batch_size
assert prev_token_ids.cpu().tolist() == expected["prev_token_ids"]
assert prev_positions.cpu().tolist() == expected["prev_positions"]
assert token_indices_to_sample.cpu().tolist() == expected["token_indices_to_sample"]
assert common_attn_metadata.seq_lens.cpu().tolist() == expected["seq_lens"]
assert common_attn_metadata.slot_mapping.cpu().tolist() == expected["slot_mapping"]
for row, cached_expect in enumerate(expected["cached"]):
assert cached_expect["len"] <= max_shift
assert (
len(cached_expect["token_ids"])
== len(cached_expect["positions"])
== len(cached_expect["slot_mapping"])
== max_shift
)
cache_len = int(cached_expect["len"])
assert int(multi_layer_eagle_metadata.cached_len[row].item()) == cache_len
assert all(v == 0 for v in cached_expect["token_ids"][cache_len:])
assert all(v == 0 for v in cached_expect["positions"][cache_len:])
assert all(v == 0 for v in cached_expect["slot_mapping"][cache_len:])
assert (
multi_layer_eagle_metadata.cached_token_ids[row].cpu().tolist()
== cached_expect["token_ids"]
)
assert (
multi_layer_eagle_metadata.cached_positions[row].cpu().tolist()
== cached_expect["positions"]
)
assert (
multi_layer_eagle_metadata.cached_slot_mappings[row].cpu().tolist()
== cached_expect["slot_mapping"]
)
@pytest.mark.parametrize(
"case", LAYER3_CASES, ids=[case["name"] for case in LAYER3_CASES]
)
def test_adjust_input_layer3_cases(proposer_stub, case):
_run_adjust_input_case(proposer_stub, case, layer_num=3)
......@@ -80,6 +80,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
......@@ -493,7 +497,10 @@ class SpeculativeConfig:
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
......
......@@ -166,7 +166,72 @@ class AnthropicServingMessages(OpenAIServingChat):
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
cls._convert_message_content(msg, openai_msg, openai_messages)
# Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.id or "",
"content": str(block.content)
if block.content
else "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls and not reasoning_parts:
continue
openai_messages.append(openai_msg)
......@@ -522,49 +587,75 @@ class AnthropicServingMessages(OpenAIServingChat):
first_item = True
finish_reason = None
state = _ActiveBlockState()
content_block_index = 0
active_block_type: str | None = None
active_block_index: int | None = None
active_block_signature: str | None = None
signature_emitted = False
active_tool_use_id: str | None = None
# Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
events: list[str] = []
if state.block_type is None:
if active_block_type is None:
return events
if (
state.block_type == "thinking"
and state.block_signature is not None
and not state.signature_emitted
active_block_type == "thinking"
and active_block_signature is not None
and not signature_emitted
):
chunk = AnthropicStreamEvent(
index=state.block_index,
index=active_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="signature_delta",
signature=state.block_signature,
signature=active_block_signature,
),
)
data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta"))
state.signature_emitted = True
signature_emitted = True
stop_chunk = AnthropicStreamEvent(
index=state.block_index,
index=active_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop"))
state.reset()
state.content_block_index += 1
active_block_type = None
active_block_index = None
active_block_signature = None
signature_emitted = False
active_tool_use_id = None
content_block_index += 1
return events
def start_block(block: AnthropicContentBlock):
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
chunk = AnthropicStreamEvent(
index=state.content_block_index,
index=content_block_index,
type="content_block_start",
content_block=block,
)
data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start")
state.start(block)
active_block_type = block.type
active_block_index = content_block_index
if block.type == "thinking":
active_block_signature = uuid.uuid4().hex
signature_emitted = False
active_tool_use_id = None
elif block.type == "tool_use":
active_block_signature = None
signature_emitted = True
active_tool_use_id = block.id
else:
active_block_signature = None
signature_emitted = True
active_tool_use_id = None
return event
async for item in generator:
......@@ -638,7 +729,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if reasoning_delta == "":
pass
else:
if state.block_type != "thinking":
if active_block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
......@@ -649,9 +740,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield start_event
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......@@ -666,7 +757,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].delta.content == "":
pass
else:
if state.block_type != "text":
if active_block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
......@@ -675,9 +766,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield start_event
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......@@ -702,7 +793,7 @@ class AnthropicServingMessages(OpenAIServingChat):
else None
)
if (
state.tool_use_id != tool_call.id
active_tool_use_id != tool_call.id
and tool_name is not None
):
for event in stop_active_block():
......@@ -720,13 +811,13 @@ class AnthropicServingMessages(OpenAIServingChat):
if (
tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_call.id
and active_tool_use_id == tool_call.id
):
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......@@ -745,13 +836,13 @@ class AnthropicServingMessages(OpenAIServingChat):
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_use_id
and active_tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......
......@@ -1101,10 +1101,10 @@ class OpenAIServingChat(OpenAIServing):
index = 0
if (
self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output
tool_parser
and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
)
and tool_parser
):
latest_delta_len = 0
if (
......@@ -1760,6 +1760,7 @@ class OpenAIServingChat(OpenAIServing):
self,
delta_message: DeltaMessage | None,
output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
......@@ -1772,6 +1773,8 @@ class OpenAIServingChat(OpenAIServing):
# include a function that has arguments
output.finish_reason is not None
and self.enable_auto_tools
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
and self.tool_parser
and delta_message
and delta_message.tool_calls
......
......@@ -262,6 +262,7 @@ def select_fp8_moe_backend(
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
supported = True
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
......
......@@ -6,6 +6,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
......@@ -40,9 +41,11 @@ class SharedHead(nn.Module):
return self.norm(hidden_states)
@support_torch_compile
class Step3p5AMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str,
) -> None:
......@@ -52,7 +55,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer(
vllm_config,
prefix=f"{prefix}.mtp_block",
......@@ -64,9 +67,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
......@@ -92,8 +98,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict(
{
str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config,
f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
......@@ -112,14 +118,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
self.embed_tokens,
current_step_idx,
)
......@@ -131,7 +136,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
)
return logits
......@@ -257,6 +262,7 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".")
if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name:
assert (
hasattr(self.config, "num_nextn_predict_layers")
......
......@@ -118,6 +118,12 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Whether to check for unstreamed tool-argument tokens in serving
"""
return True
class ToolParserManager:
"""
......
......@@ -2,13 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any
from xml.parsers.expat import ParserCreate
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
......@@ -23,1500 +22,1144 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class StreamingXMLToolCallParser:
"""
Simplified streaming XML tool call parser
Supports streaming input, parsing, and output
"""
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self):
self.reset_streaming_state()
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
# Tool configuration information
self.tools: list[ChatCompletionToolsParam] | None = None
# Sentinel tokens for streaming mode
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.function_start_token: str = "<function="
self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_start_token: str = "<parameter="
self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
self.is_tool_call_started: bool = False
self.failed_count: int = 0
def reset_streaming_state(self):
"""Reset streaming parsing state"""
self.deltas = []
# state for streaming
self.tool_call_index = 0
self.current_call_id = None
self.last_completed_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.streaming_buffer = ""
self.last_processed_pos = 0
self.text_content_buffer = ""
# state for preprocessing and deferred parsing
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
# recreate parser
self.parser = ParserCreate()
self.setup_parser()
def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
"""
Parse single streaming XML chunk and return Delta response
This is the actual streaming interface that receives chunks
one by one and maintains internal state
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
Args:
xml_chunk: Single XML chunk string
Returns:
DeltaMessage: Contains delta information generated by this chunk,
returns empty response if no complete elements
"""
# Record delta count before processing
initial_delta_count = len(self.deltas)
entry_call_id = self.current_call_id
entry_tool_call_index = self.tool_call_index
# Regex patterns
self.tool_call_complete_regex = re.compile(
r"<tool_call>(.*?)</tool_call>", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"<function(?:=|\s+)?(.*?)</function>", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>", re.DOTALL
)
self.streaming_buffer += xml_chunk
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
found_elements = self._process_complete_xml_elements()
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
fallback_call_id = None
if entry_call_id is not None:
if (
self.current_call_id == entry_call_id
and self.tool_call_index == entry_tool_call_index
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Step3p5 RL Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
# Get EOS token ID for EOS detection
self.eos_token_id = getattr(self.model_tokenizer, "eos_token_id", None)
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Skip the remaining_call calculation in serving
"""
return False
def _reset_streaming_state(self):
"""Reset all streaming state for a new request."""
self._processed_length: int = 0 # Position of last processed character
self._tool_call_index: int = 0 # Number of tool calls processed so far
self.streaming_request = None # Current request being processed
def _get_arguments_config(
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function") and hasattr(config.function, "name")
):
fallback_call_id = entry_call_id
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning("Tool '%s' is not defined in the tools list.", func_name)
return {}
def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value.",
param_name,
func_name,
)
return param_value
if (
isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
self.current_call_id is not None
and self.tool_call_index == entry_tool_call_index + 1
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
fallback_call_id = self.current_call_id
if found_elements:
# If complete elements found, check if end events were missed
# some tags may not have been triggered
try:
new_deltas = self.deltas[initial_delta_count:]
# If this chunk contains </function>
# but didn't generate '}', then complete it
if (
fallback_call_id is not None
and self.function_end_token in xml_chunk
):
# - Added '}' (non-empty parameter ending)
# - Added '{}' (empty parameter function)
has_function_close = any(
(
td.tool_calls
and any(
(
tc.function
and tc.id == fallback_call_id
and isinstance(tc.function.arguments, str)
and (tc.function.arguments in ("}", "{}"))
)
for tc in td.tool_calls
)
return int(param_value)
except (ValueError, TypeError):
try:
float_value = float(param_value)
if float_value.is_integer():
return int(float_value)
except (ValueError, TypeError):
pass
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return int(literal_value)
if isinstance(literal_value, (int, float)):
return (
int(literal_value)
if float(literal_value).is_integer()
else literal_value
)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (int, float)):
return (
float(literal_value)
if float(literal_value) - int(float(literal_value)) != 0
else int(float(literal_value))
)
for td in new_deltas
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
normalized_value = param_value.strip().lower()
if normalized_value in ["true", "false"]:
return normalized_value == "true"
if normalized_value in ["1", "0"]:
return normalized_value == "1"
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
else:
if (
param_type in ["object", "array", "arr"]
or param_type.startswith("dict")
or param_type.startswith("list")
):
try:
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (list, dict)):
return literal_value
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be parsed "
"as JSON in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
if not has_function_close:
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.current_function_name:
self._end_element("function")
# If this chunk contains </tool_call>
# but didn't generate final empty delta, then complete it
return param_value
try:
literal_value = ast.literal_eval(param_value) # safer
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
if (
fallback_call_id is not None
and self.tool_call_end_token in xml_chunk
isinstance(literal_value, (list, dict, str, int, float, bool))
or literal_value is None
):
has_toolcall_close = any(
(
td.tool_calls
and any(
(
tc.type == "function"
and tc.function
and tc.function.arguments == ""
and tc.id == fallback_call_id
)
for tc in td.tool_calls
)
)
for td in new_deltas
)
if not has_toolcall_close:
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.current_function_name:
self._end_element("function")
self._end_element("tool_call")
except Exception as e:
logger.warning("Error with fallback parsing: %s", e)
# Merge newly generated deltas into single response
result_delta = self._merge_new_deltas_to_single_response(
initial_delta_count
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return result_delta
return param_value
def _parse_parameters_fallback(
self,
parameters: str,
allowed_param_names: set[str] | None = None,
) -> list[tuple[str, str]]:
"""Fallback parser for malformed parameter tags."""
param_pairs: list[tuple[str, str]] = []
pos = 0
while True:
start = parameters.find(self.parameter_prefix, pos)
if start == -1:
break
name_start = start + len(self.parameter_prefix)
name_end = parameters.find(">", name_start)
if name_end == -1:
newline_idx = parameters.find("\n", name_start)
end_tag = parameters.find(self.parameter_end_token, name_start)
next_param = parameters.find(self.parameter_prefix, name_start)
candidates = [
idx for idx in [newline_idx, end_tag, next_param] if idx != -1
]
if not candidates:
break
name_end = min(candidates)
value_start = name_end
else:
value_start = name_end + 1
param_name = parameters[name_start:name_end].strip()
next_param = parameters.find(self.parameter_prefix, value_start)
end_tag = parameters.find(self.parameter_end_token, value_start)
if end_tag == -1 or (next_param != -1 and next_param < end_tag):
end = next_param if next_param != -1 else len(parameters)
pos = end
else:
end = end_tag
pos = end + len(self.parameter_end_token)
param_value = parameters[value_start:end]
if allowed_param_names is None or param_name in allowed_param_names:
param_pairs.append((param_name, param_value))
return param_pairs
def _is_valid_json_arguments(self, arguments: str) -> bool:
"""Check if arguments can be loaded as JSON."""
try:
json.loads(arguments)
except Exception:
return False
return True
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.index(">")
# check empty function name
function_name = function_call_str[:end_index].strip()
if function_name.startswith("="):
function_name = function_name.lstrip("=").strip()
if not function_name or function_name.strip("'\"") == "":
logger.warning("Empty function name in tool call.")
return None
if function_name[0] in "\"'" and function_name[-1] == function_name[0]:
function_name = function_name[1:-1].strip()
if not function_name:
logger.warning("Empty function name in tool call.")
return None
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :]
param_dict = {}
match_texts = self.tool_call_parameter_regex.findall(parameters)
use_fallback = False
if match_texts:
for match_text in match_texts:
if self.parameter_prefix in match_text or ">" not in match_text:
use_fallback = True
break
else:
# No complete elements, check if there's unoutput text content
if self.text_content_buffer and self.tool_call_index == 0:
# Has text content but no tool_call yet, output text content
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer to avoid duplicate output
self.text_content_buffer = ""
return text_delta
# If this chunk contains end tags but wasn't triggered by parser,
# manually complete end events
# Only execute when still on the same call as when entered,
# to prevent accidentally closing new calls
# in multi <tool_call> scenarios
if fallback_call_id is not None and (
self.function_end_token in xml_chunk
or self.tool_call_end_token in xml_chunk
):
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.function_end_token in xml_chunk and self.current_function_name:
self._end_element("function")
if self.tool_call_end_token in xml_chunk:
self._end_element("tool_call")
# Return the merged delta result generated by this fallback
result_delta = self._merge_new_deltas_to_single_response(
initial_delta_count
)
return result_delta
use_fallback = self.parameter_prefix in parameters
# No complete elements, return empty response
return DeltaMessage(content=None)
if use_fallback:
allowed_param_names = (
set(param_config.keys())
if isinstance(param_config, dict) and param_config
else None
)
param_pairs = self._parse_parameters_fallback(
parameters, allowed_param_names
)
else:
param_pairs = []
for match_text in match_texts:
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1 :])
param_pairs.append((param_name, param_value))
for param_name, param_value in param_pairs:
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name
)
def _escape_xml_special_chars(self, text: str) -> str:
"""
Escape XML special characters
Args:
text: Original text
Returns:
Escaped text
"""
xml_escapes = {
"&": "&amp;",
"<": "&lt;",
">": "&gt;",
'"': "&quot;",
"'": "&apos;",
}
try:
arguments = json.dumps(param_dict, ensure_ascii=False)
except Exception as e:
logger.warning("Error in converting parameter value: %s", e)
return None
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=arguments),
)
for char, escape in xml_escapes.items():
text = text.replace(char, escape)
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
raw_tool_calls = self.tool_call_complete_regex.findall(model_output)
return text
# if no closed tool_call tags found, return empty list
if len(raw_tool_calls) == 0:
return []
def _process_complete_xml_elements(self) -> bool:
"""
Process complete XML elements in buffer
raw_function_calls = []
for tool_call in raw_tool_calls:
function_matches = self.tool_call_function_regex.findall(tool_call)
raw_function_calls.extend(function_matches)
Returns:
bool: Whether complete elements were found and processed
"""
found_any = False
return raw_function_calls
while self.last_processed_pos < len(self.streaming_buffer):
# Find next complete xml element
element, end_pos = self._find_next_complete_element(self.last_processed_pos)
if element is None:
# No complete element found, wait for more data
break
def _check_format(self, model_output: str) -> bool:
"""Check if model output contains properly formatted tool call.
# Check if this element should be skipped
if self._should_skip_element(element):
self.last_processed_pos = end_pos
continue
Requirements:
1. Must have closed tool_call tags (<tool_call>...</tool_call>)
2. Must have closed function tags (<function=...</function>)
3. If parameter tags exist, they must be closed and correct
# Found complete XML element, process it
try:
preprocessed_element = self._preprocess_xml_chunk(element)
# Check if this is the first tool_call start
if (
(
preprocessed_element.strip().startswith("<tool_call>")
or preprocessed_element.strip().startswith("<function name=")
)
and self.tool_call_index == 0
) and self.text_content_buffer:
# First tool_call starts,
# output previously collected text content first
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer for potential subsequent text content
self.text_content_buffer = ""
# If a new tool_call starts and
# there are already completed tool_calls with function name
if (
preprocessed_element.strip().startswith("<tool_call>")
and self.tool_call_index > 0
and self.current_call_id
and self.current_function_name
):
# Reset parser state but preserve generated deltas
if self.current_param_name:
self._end_element("parameter")
if self.current_function_open:
self._end_element("function")
# Output final tool_call tail delta
final_delta = DeltaMessage(
role=None,
content=None,
reasoning=None,
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
],
)
self._emit_delta(final_delta)
# Reset XML parser and current call state
self._reset_xml_parser_after_tool_call()
# Parse preprocessed element
self.parser.Parse(preprocessed_element, False)
found_any = True
Returns True if the format is valid, False otherwise.
"""
# Check 1: Must have closed tool_call tags
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return False
except Exception as e:
logger.warning("Error when parsing XML elements: %s", e)
# Check 2: Must have closed function tags within tool_call
has_valid_function = False
for tool_call_content in tool_call_matches:
function_matches = self.tool_call_function_regex.findall(tool_call_content)
if len(function_matches) > 0:
has_valid_function = True
# Check if there's an unclosed function tag
if (
self.tool_call_prefix in tool_call_content
and self.function_end_token not in tool_call_content
):
return False
# Update processed position
self.last_processed_pos = end_pos
if not has_valid_function:
return False
return found_any
# Check 3: If parameter tags exist, they must be closed and correct
for tool_call_content in tool_call_matches:
# Count opening and closing parameter tags
param_open_count = tool_call_content.count(self.parameter_prefix)
param_close_count = tool_call_content.count(self.parameter_end_token)
# If there are parameter tags, they must be balanced
if param_open_count > 0:
if param_open_count != param_close_count:
return False
# Check if all parameter tags are properly closed using regex
param_matches = self.tool_call_parameter_regex.findall(
tool_call_content
)
if len(param_matches) != param_open_count:
return False
def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str:
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk = self._fix_missing_equals_in_function_tag(chunk)
return True
for tag_type in ["parameter", "function"]:
pattern = f"<{tag_type}="
if pattern not in chunk:
def _wrap_missing_tool_call_tags(self, model_output: str) -> str:
"""Wrap bare <function=...></function> blocks with <tool_call> tags."""
if (
self.tool_call_prefix not in model_output
or self.function_end_token not in model_output
):
return model_output
def _wrap_bare_functions(text: str) -> str:
pos = 0
wrapped_parts: list[str] = []
while True:
func_idx = text.find(self.tool_call_prefix, pos)
if func_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx = text.find(self.function_end_token, func_idx)
if end_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx += len(self.function_end_token)
wrapped_parts.append(text[pos:func_idx])
wrapped_parts.append(self.tool_call_start_token)
wrapped_parts.append(text[func_idx:end_idx])
wrapped_parts.append(self.tool_call_end_token)
ws_idx = end_idx
while ws_idx < len(text) and text[ws_idx].isspace():
ws_idx += 1
if text.startswith(self.tool_call_end_token, ws_idx):
if ws_idx > end_idx:
wrapped_parts.append(text[end_idx:ws_idx])
pos = ws_idx + len(self.tool_call_end_token)
else:
pos = end_idx
return "".join(wrapped_parts)
tool_call_ranges = [
match.span()
for match in self.tool_call_complete_regex.finditer(model_output)
]
if not tool_call_ranges:
return _wrap_bare_functions(model_output)
wrapped_parts: list[str] = []
pos = 0
for start, end in tool_call_ranges:
if start < pos:
continue
wrapped_parts.append(_wrap_bare_functions(model_output[pos:start]))
wrapped_parts.append(model_output[start:end])
pos = end
wrapped_parts.append(_wrap_bare_functions(model_output[pos:]))
return "".join(wrapped_parts)
def _normalize_prev_arguments(self, args_value: Any) -> Any:
if isinstance(args_value, str):
try:
return json.loads(args_value)
except (TypeError, ValueError, json.JSONDecodeError):
return args_value
return args_value
def _update_prev_tool_call_state(self, tool_calls: list[ToolCall]) -> None:
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
for tool_call in tool_calls:
if not tool_call or not tool_call.function:
continue
args_value = tool_call.function.arguments
if isinstance(args_value, str):
args_json = args_value
elif args_value is None:
args_json = ""
else:
try:
args_json = json.dumps(args_value, ensure_ascii=False)
except (TypeError, ValueError):
args_json = str(args_value)
prev_args = self._normalize_prev_arguments(args_json)
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": prev_args,
}
)
try:
expected_args_json = json.dumps(prev_args, ensure_ascii=False)
except (TypeError, ValueError):
expected_args_json = args_json
start_idx = chunk.find(pattern)
after_tag = chunk[start_idx:]
gt_pos = after_tag.find(">")
lt_pos = after_tag.find("<", len(pattern))
# Serving may subtract the latest delta length from
# streamed_args_for_tool to detect unstreamed suffixes. Since this
# parser emits full arguments at once, store expected+actual so
# the subtraction yields expected_args_json and no resend occurs.
self.streamed_args_for_tool.append(expected_args_json + args_json)
# Skip if already well-formed
if (
gt_pos != -1
and (lt_pos == -1 or gt_pos < lt_pos)
and pattern in after_tag[:gt_pos]
):
continue
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
try:
origin_model_output = model_output
try:
# Fallback: handle outputs without <tool_call> wrapper.
origin_model_output = self._wrap_missing_tool_call_tags(
origin_model_output
)
model_output = origin_model_output
except Exception:
pass
# Use streaming-like approach: process position by position
valid_tool_calls = []
content_parts = []
processed_length = 0
while processed_length < len(model_output):
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
model_output, processed_length
)
# Extract tag name (stop at space, newline, or <)
content = chunk[start_idx + len(pattern) :]
end_pos = next(
(i for i, ch in enumerate(content) if ch in (" ", "\n", "<")),
len(content),
)
tag_name = content[:end_pos]
# Case 1: No more tool calls - add remaining as content
if tool_start_idx == -1:
remaining = model_output[processed_length:]
if remaining:
content_parts.append(remaining)
break
# Case 2: Content before tool call
if tool_start_idx > processed_length:
content_before = model_output[processed_length:tool_start_idx]
# Skip whitespace-only content between tool calls
# Check if we just ended a tool call and this is pure whitespace
if processed_length > 0:
text_before = model_output[:processed_length]
if (
text_before.rstrip().endswith(self.tool_call_end_token)
and content_before.strip() == ""
):
# Skip whitespace between tool calls
pass
else:
content_parts.append(content_before)
else:
content_parts.append(content_before)
if not tag_name:
continue
# Case 3: Try to find complete tool call
tool_end_idx = self._find_first_complete_tool_call_end(
model_output, tool_start_idx
)
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx
if tag_name.startswith(f"{tag_type}="):
tag_name = tag_name[len(tag_type) + 1 :]
# If tool call is incomplete - add remaining as content and stop
if tool_end_idx == -1:
remaining = model_output[tool_start_idx:]
if remaining:
content_parts.append(remaining)
break
# Extract and try to parse the complete tool call
tool_call_text = model_output[tool_start_idx:tool_end_idx]
parsed_result = self.extract_tool_calls_basic(tool_call_text, request)
# If parsing succeeded, record the tool call(s)
if parsed_result.tools_called and parsed_result.tool_calls:
valid_tool_calls.extend(parsed_result.tool_calls)
processed_length = tool_end_idx
else:
# Parsing failed - treat this tool call as content
content_parts.append(tool_call_text)
processed_length = tool_end_idx
# Remove trailing non-alphanumeric chars (keep - and _)
while tag_name and not (
tag_name[-1].isalnum() or tag_name[-1] in ("-", "_")
):
tag_name = tag_name[:-1]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(valid_tool_calls)
if not tag_name:
continue
# Combine content parts
content = "".join(content_parts) if content_parts else None
# Validate parameter exists in tool definition
if tag_type == "parameter" and not self._validate_parameter_name(tag_name):
continue
return ExtractedToolCallInformation(
tools_called=(len(valid_tool_calls) > 0),
tool_calls=valid_tool_calls,
content=content if content else None,
)
# Apply fix
chunk = chunk.replace(
f"<{tag_type}={content[:end_pos]}", f"<{tag_type}={tag_name}>", 1
except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
return chunk
def extract_tool_calls_basic(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
model_output = self._wrap_missing_tool_call_tags(model_output)
# Quick check to avoid unnecessary processing
if not self._check_format(model_output):
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _fix_missing_equals_in_function_tag(self, chunk: str) -> str:
"""
Fix missing = in function tags: <function xxx> or <functionxxx>
Examples:
<function execute_bash> -> <function=execute_bash>
<functionexecute_bash> -> <function=execute_bash>
Only fixes if function name exists in tool definition
"""
# already correct
if "<function=" in chunk:
return chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1 = r"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1 = re.search(pattern1, chunk)
if match1:
func_name = match1.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match1.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2 = r"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2 = re.search(pattern2, chunk)
if match2:
func_name = match2.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match2.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
return chunk
def _validate_function_name(self, func_name: str) -> bool:
"""Check if function name exists in tool definitions"""
if not self.tools:
return False
try:
function_calls = self._get_function_calls(model_output)
if len(function_calls) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
for tool in self.tools:
if (
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
and tool.function.name == func_name
):
return True
tool_calls: list[ToolCall] = []
for function_call_str in function_calls:
tool_call = self._parse_xml_function_call(
function_call_str, request.tools
)
if tool_call:
tool_calls.append(tool_call)
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
for tool_call in tool_calls:
if (
not tool_call.function
or tool_call.function.arguments is None
or not self._is_valid_json_arguments(tool_call.function.arguments)
):
logger.warning(
"Invalid JSON arguments in tool call, falling back to content."
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
return False
# Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(tool_calls)
def _validate_parameter_name(self, param_name: str) -> bool:
"""Check if parameter exists in current function's tool definition"""
if not self.tools or not self.current_function_name:
return True
# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
content = model_output[:content_index] # .rstrip()
for tool in self.tools:
if (
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return True
params = tool.function.parameters
if isinstance(params, dict):
properties = params.get("properties", params)
return param_name in properties
break
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
return True
except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _should_skip_element(self, element: str) -> bool:
"""
Determine whether an element should be skipped
def _find_first_complete_tool_call_end(self, text: str, start_pos: int = 0) -> int:
"""Find the end position of the first complete tool call.
Args:
element: Element to evaluate
text: Text to search in
start_pos: Position to start searching from
Returns:
bool: True means should skip, False means should process
"""
Position after the first </tool_call> tag, or -1 if incomplete
# If it's a tool_call XML tag, don't skip
if (
element.startswith(self.tool_call_start_token)
or element.startswith(self.function_start_token)
or element.startswith(self.parameter_start_token)
):
return False
# If currently not parsing tool calls and not blank,
# collect this text instead of skipping
# Only process other XML elements after tool_call appears,
# otherwise treat as plain text
if self.current_call_id is None and element:
# Collect text content to buffer
self.text_content_buffer += element
return True # Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if self.current_call_id is not None:
return False
Example:
"<tool_call>...</tool_call>..." returns position after </tool_call>
"""
# Find tool call start
start_idx = text.find(self.tool_call_start_token, start_pos)
if start_idx == -1:
return -1
# Find matching end token
end_idx = text.find(
self.tool_call_end_token, start_idx + len(self.tool_call_start_token)
)
if end_idx == -1:
return -1 # Incomplete tool call
# Skip blank content
return not element
# Return position after end token
return end_idx + len(self.tool_call_end_token)
def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
"""
Find next complete XML element from specified position
def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int:
"""Find the start position of next tool call.
Args:
start_pos: Position to start searching
text: Text to search in
start_pos: Position to start searching from
Returns:
(Complete element string, element end position),
returns (None, start_pos) if no complete element found
Position of <tool_call> token, or -1 if not found
"""
buffer = self.streaming_buffer[start_pos:]
return text.find(self.tool_call_start_token, start_pos)
if not buffer:
return None, start_pos
def _extract_content_between_tool_calls_list(self, text: str) -> list[str]:
"""Extract content segments after each tool call.
if buffer.startswith("<"):
# Check if this is an incomplete parameter/function tag
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param = (
buffer.startswith("<parameter=") and ">" not in buffer.split("\n")[0]
)
is_incomplete_func = (
buffer.startswith("<function=") and ">" not in buffer.split("\n")[0]
)
For n tool calls, returns n segments where segment[i] is the content
after tool_call[i] (before tool_call[i+1] or at the end).
if is_incomplete_param or is_incomplete_func:
# Find the corresponding closing tag
tag_type = "parameter" if is_incomplete_param else "function"
closing_tag = f"</{tag_type}>"
closing_pos = buffer.find(closing_tag)
if closing_pos != -1:
# Found closing tag, return complete element including closing tag
complete_element = buffer[: closing_pos + len(closing_tag)]
return complete_element, start_pos + closing_pos + len(closing_tag)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end = buffer.find("<", 1)
tag_end2 = buffer.find(">", 1)
if tag_end != -1 and tag_end2 != -1:
# Next nearest is <
if tag_end < tag_end2:
return buffer[:tag_end], start_pos + tag_end
# Next nearest is >, means found XML element
else:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
elif tag_end != -1:
return buffer[:tag_end], start_pos + tag_end
elif tag_end2 != -1:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
else:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if self.current_call_id is None:
# Check if might be start of <tool_call>
if buffer == "<tool_call>"[: len(buffer)]:
# Might be start of <tool_call>, wait for more data
return None, start_pos
elif (
buffer.startswith("<function=")
or buffer == "<function="[: len(buffer)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return None, start_pos
else:
# Not start of <tool_call> or <function=, treat as text
return buffer, start_pos + len(buffer)
else:
# When parsing tool calls,
# wait for more data to get complete tag
return None, start_pos
else:
# Find text content (until next < or buffer end)
next_tag_pos = buffer.find("<")
if next_tag_pos != -1:
# Found text content
text_content = buffer[:next_tag_pos]
return text_content, start_pos + next_tag_pos
else:
# Buffer end is all text, process
# (no longer wait for more data)
remaining = buffer
return remaining, start_pos + len(remaining)
def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage:
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Empty or whitespace-only segments are represented as empty string "".
Args:
initial_count: Delta count before processing
text: Text containing tool calls
Returns:
Merged DeltaMessage containing all newly generated delta information
List of content segments (one per tool call)
"""
if len(self.deltas) <= initial_count:
return DeltaMessage(content=None)
content_segments = []
pos = 0
# Get newly generated deltas
new_deltas = self.deltas[initial_count:]
while True:
# Find end of current tool call
end_pos = text.find(self.tool_call_end_token, pos)
if end_pos == -1:
break
if len(new_deltas) == 1:
# Only one new delta, return directly
return new_deltas[0]
# Move past the end token
end_pos += len(self.tool_call_end_token)
# Merge multiple new deltas
merged_tool_calls: list[DeltaToolCall] = []
merged_content: str = ""
# Find start of next tool call
next_start = self._find_tool_call_start(text, end_pos)
for delta in new_deltas:
if delta.content:
merged_content += delta.content
if delta.tool_calls:
# For tool_calls, we need to intelligently merge arguments
for tool_call in delta.tool_calls:
# Find if there's already a tool_call with the same call_id
existing_call = None
for existing in merged_tool_calls:
if existing.id == tool_call.id:
existing_call = existing
break
if existing_call and existing_call.function:
# Merge to existing tool_call
if tool_call.function and tool_call.function.name:
existing_call.function.name = tool_call.function.name
if (
tool_call.function
and tool_call.function.arguments is not None
):
if existing_call.function.arguments is None:
existing_call.function.arguments = ""
# For streaming JSON parameters,
# simply concatenate in order
new_args = tool_call.function.arguments
existing_call.function.arguments += new_args
if tool_call.type:
existing_call.type = tool_call.type
else:
# Add new tool_call
merged_tool_calls.append(tool_call)
# Extract content between current end and next start (or text end)
content = text[end_pos:next_start] if next_start != -1 else text[end_pos:]
return DeltaMessage(
content=merged_content if merged_content else None,
tool_calls=merged_tool_calls,
)
def _preprocess_xml_chunk(self, chunk: str) -> str:
"""
Preprocess XML chunk, handle non-standard formats,
and escape special characters
# Store content (empty string if whitespace-only)
content_segments.append(content if content.strip() else "")
Args:
chunk: Original XML chunk
if next_start == -1:
break
pos = next_start
Returns:
Processed XML chunk
"""
return content_segments
# Check if this is a tool_call related element
is_tool_call = False
if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
self.tool_call_end_token
):
is_tool_call = True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if (
chunk.startswith(self.function_start_token)
or chunk.startswith(self.function_end_token)
or chunk.startswith("<function ")
or re.match(r"^<function[a-zA-Z_]", chunk)
): # <functionXXX without space or =
is_tool_call = True
if chunk.startswith(self.parameter_start_token) or chunk.startswith(
self.parameter_end_token
):
is_tool_call = True
def _convert_tool_calls_to_deltas(
self, tool_calls: list[ToolCall], starting_index: int = 0
) -> list[DeltaToolCall]:
"""Convert complete ToolCall list to DeltaToolCall list.
# Fallback: fix incomplete <parameter= or <function= tags without
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if (
self.current_call_id is not None
or chunk.startswith("<function")
or chunk.startswith("<parameter")
):
chunk = self._fix_incomplete_tag_in_chunk(chunk)
# Handle <function=name> format -> <function name="name">
processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk)
# Handle <parameter=name> format -> <parameter name="name">
processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed)
original_chunk = chunk
# If in parameter value accumulation mode
if self._pre_inside_parameter:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if processed.startswith("</parameter>"):
body_text = self._pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self.defer_current_parameter = True
self.deferred_param_raw_value = body_text
# Clean up state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
safe_text = self._escape_xml_special_chars(body_text)
return f"{safe_text}</parameter>"
else:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if self._pre_param_buffer == "":
# Get current parameter type
param_type = (
self._get_param_type(self._pre_current_param_name)
if self._pre_current_param_name
else "string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type = param_type in ["object"]
is_complex_type = (
param_type in ["array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
Returns complete tool calls without splitting into fragments.
# Only delay when contains container symbols
# and has single quotes and is complex type
has_container_hint = (
("[" in original_chunk)
or ("{" in original_chunk)
or ("(" in original_chunk)
)
Args:
tool_calls: List of tool calls to convert
starting_index: Starting index for tool calls (default 0)
# Determine if deferred parsing is needed
need_defer = False
if is_complex_type:
# Complex type, always need deferred parsing
need_defer = True
elif (
is_object_type
and has_container_hint
and ("'" in original_chunk)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer = True
if not need_defer:
# No need for deferred parsing,
# exit parameter mode directly
self._pre_inside_parameter = False
return self._escape_xml_special_chars(original_chunk)
self._pre_param_buffer += original_chunk
return ""
# Parameter start: enable accumulation
if processed.startswith("<parameter name="):
m = re.match(r'<parameter name="([^"]+)">', processed)
if m:
self._pre_current_param_name = m.group(1)
self._pre_inside_parameter = True
self._pre_param_buffer = ""
return processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if not is_tool_call:
processed = self._escape_xml_special_chars(processed)
return processed
def _emit_delta(self, delta: DeltaMessage):
"""Emit Delta response (streaming output)"""
self.deltas.append(delta)
def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
Returns:
List of DeltaToolCall with complete arguments
"""
# First close unclosed parameters
if self.current_param_name:
self._end_element("parameter")
# If about to start new function or tool_call,
# and there are unclosed functions, close function first
if incoming_tag in ("function", "tool_call") and self.current_function_name:
self._end_element("function")
# If about to start new tool_call,
# and there are unclosed tool_calls, close tool_call first
if incoming_tag == "tool_call" and self.current_call_id:
self._end_element("tool_call")
def _start_element(self, name: str, attrs: dict[str, str]):
"""Handle XML start element events"""
if name == "root":
return
if name == "tool_call":
# Before opening new tool_call,
# automatically complete previous unclosed tags
self._auto_close_open_parameter_if_needed("tool_call")
self.parameters = {}
self.current_call_id = make_tool_call_id()
self.current_param_is_first = True
self.tool_call_index += 1
elif name.startswith("function") or (name == "function"):
# If missing tool_call, manually complete
if not self.current_call_id:
self._start_element("tool_call", {})
# Before opening new function,
# automatically complete previous unclosed tags (parameter/function)
self._auto_close_open_parameter_if_needed("function")
function_name = self._extract_function_name(name, attrs)
self.current_function_name = function_name
self.current_function_open = True
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=function_name, arguments=""
),
)
]
)
self._emit_delta(delta)
elif name.startswith("parameter") or (name == "parameter"):
# If previous parameter hasn't ended normally,
# complete its end first, then start new parameter
self._auto_close_open_parameter_if_needed("parameter")
param_name = self._extract_parameter_name(name, attrs)
self.current_param_name = param_name
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False # Reset start quote flag
# Only output parameter name and colon,
# don't output quotes
# decide after parameter value type is determined
if param_name:
if not self.parameters:
# First parameter
# start JSON, only output parameter name and colon
json_start = f'{{"{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_start
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = True
else:
# Subsequent parameters
# add comma and parameter name, no quotes
json_continue = f', "{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_continue
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = False
def _char_data(self, data: str):
"""Handle XML character data events"""
if data and self.current_param_name:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if self.defer_current_parameter:
original_data = data
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
return
param_type = self._get_param_type(self.current_param_name)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if not self.current_param_value and data.startswith("\n"):
data = data[1:]
# Output start quote for string type (if not already output)
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"]
and not self.start_quote_emitted
):
quote_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
delta_tool_calls = []
for i, tool_call in enumerate[ToolCall](tool_calls):
index = starting_index + i
tool_id = self._generate_tool_call_id()
# Create complete DeltaToolCall with full arguments
delta_tool_calls.append(
DeltaToolCall(
index=index,
id=tool_id,
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
type="function",
)
self._emit_delta(quote_delta)
self.start_quote_emitted = True
if not data:
return
original_data = data
# Delay output of trailing newline
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
# convert parameter value by param_type
converted_value = self._convert_param_value(
self.current_param_value, param_type
)
output_data = self._convert_for_json_streaming(converted_value, param_type)
delta_data = output_data[len(self.current_param_value_converted) :]
self.current_param_value_converted = output_data
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=delta_data),
)
]
)
self._emit_delta(delta)
def _end_element(self, name: str):
"""Handle XML end element events"""
return delta_tool_calls
if name == "root":
return
# If function or tool_call ends and there are still unclosed parameters,
# complete parameter end first
if (
name.startswith("function") or name == "function" or name == "tool_call"
) and self.current_param_name:
self._auto_close_open_parameter_if_needed()
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming text using complete parsing.
Strategy:
1. Accumulate text in buffer and track processed position
2. In each iteration, try to extract content or complete tool calls
3. Parse complete tool calls using non-streaming method
4. Convert parsed results to delta sequence
5. Handle EOS token to flush incomplete tool calls as content
"""
# Initialize state for new request
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# Check for EOS token
has_eos = (
self.eos_token_id is not None
and delta_token_ids
and self.eos_token_id in delta_token_ids
)
if (
name.startswith("parameter") or name == "parameter"
) and self.current_param_name:
# End current parameter
param_name = self.current_param_name
param_value = self.current_param_value
# If in deferred parsing mode,
# perform overall parsing on raw content
# accumulated in preprocessing stage and output once
if self.defer_current_parameter:
raw_text = (
self.deferred_param_raw_value
if self.deferred_param_raw_value
else param_value
)
parsed_value = None
output_arguments = None
try:
# If previously delayed trailing newline,
# add it back before parsing
if self.should_emit_end_newline:
raw_for_parse = raw_text + "\n"
else:
raw_for_parse = raw_text
parsed_value = ast.literal_eval(raw_for_parse)
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
except Exception:
# Fallback: output as string as-is
output_arguments = json.dumps(raw_text, ensure_ascii=False)
parsed_value = raw_text
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=output_arguments
),
)
]
# If no delta text, check if we need to return empty delta for finish_reason
if not delta_text and not has_eos:
# Check if this is an EOS token after all tool calls are complete
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
self._emit_delta(delta)
# Clean up and store
self.should_emit_end_newline = False
self.parameters[param_name] = parsed_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
return
param_type = self._get_param_type(param_name)
# convert complete parameter value by param_type
converted_value = self._convert_param_value(param_value, param_type)
# Decide whether to add end quote based on parameter type
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# For empty string parameters, need special handling
if not param_value and not self.start_quote_emitted:
# No start quote output,
# directly output complete empty string
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='""'),
)
]
)
self._emit_delta(delta)
else:
# Non-empty parameter value, output end quote
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
)
self._emit_delta(delta)
self.should_emit_end_newline = False
# Store converted value
self.parameters[param_name] = converted_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
elif name.startswith("function") or name == "function":
# if there are parameters, close JSON object
if self.parameters:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="}"),
)
]
)
self._emit_delta(delta)
# return empty object
else:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="{}"),
)
]
)
self._emit_delta(delta)
self.current_function_open = False
self.current_function_name = (
None # Clear function name to prevent duplicate closing
)
elif name == "tool_call":
# Before ending tool_call,
# ensure function is closed to complete missing right brace
if self.current_function_open:
# If there are still unclosed parameters, close them first
if self.current_param_name:
self._end_element("parameter")
# Close function, ensure output '}' or '{}'
self._end_element("function")
# Final Delta
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
]
)
self._emit_delta(delta)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
return None
# Check if there's text content to output (between tool_calls)
if self.text_content_buffer.strip():
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Process all available content
accumulated_deltas: list[DeltaMessage] = []
self._reset_xml_parser_after_tool_call()
while self._has_unprocessed_content(current_text):
# Try to process next chunk (content or tool call)
delta = self._process_next_chunk(current_text)
def setup_parser(self):
"""Set up XML parser event handlers"""
self.parser.buffer_text = True
self.parser.StartElementHandler = self._start_element
self.parser.EndElementHandler = self._end_element
self.parser.CharacterDataHandler = self._char_data
if delta is None:
# Cannot proceed further, need more tokens
break
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
"""Set tool configuration information"""
self.tools = tools
# Accumulate deltas
if isinstance(delta, list):
accumulated_deltas.extend(delta)
else:
accumulated_deltas.append(delta)
# Handle EOS: flush any remaining incomplete tool calls as content
if has_eos:
remaining_delta = self._flush_remaining_content(current_text)
if remaining_delta:
accumulated_deltas.append(remaining_delta)
# If no remaining content but we have tool calls, return empty delta
elif len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
accumulated_deltas.append(DeltaMessage(content=""))
# Return results
return self._format_delta_result(accumulated_deltas)
def _has_unprocessed_content(self, current_text: str) -> bool:
"""Check if there's unprocessed content in the buffer."""
return self._processed_length < len(current_text)
def _process_next_chunk(
self, current_text: str
) -> DeltaMessage | list[DeltaMessage] | None:
"""Process next chunk: either regular content or a complete tool call.
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract function name from various formats"""
if attrs and "name" in attrs:
return attrs["name"]
Args:
current_text: Current accumulated text
if "=" in name:
parts = name.split("=", 1)
if len(parts) == 2 and parts[0] == "function":
return parts[1]
Returns:
- DeltaMessage or list of DeltaMessage if processed successfully
- None if cannot proceed (need more tokens)
"""
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
current_text, self._processed_length
)
return None
# Case 1: No tool call found - return remaining content
if tool_start_idx == -1:
return self._process_content(
current_text, self._processed_length, len(current_text)
)
def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract parameter name from various formats"""
if attrs and "name" in attrs:
return attrs["name"]
# Case 2: Content before tool call
if tool_start_idx > self._processed_length:
return self._process_content(
current_text, self._processed_length, tool_start_idx
)
if "=" in name:
parts = name.split("=", 1)
if len(parts) == 2 and parts[0] == "parameter":
return parts[1]
# Case 3: Tool call at current position
# Find end of the first complete tool call
tool_end_idx = self._find_first_complete_tool_call_end(
current_text, tool_start_idx
)
return None
if tool_end_idx == -1:
# Tool call incomplete, wait for more tokens
return None
# Process complete tool call
return self._process_complete_tool_calls(
current_text, tool_start_idx, tool_end_idx
)
def _process_content(
self, current_text: str, start_pos: int, end_pos: int
) -> DeltaMessage | None:
"""Process regular content (non-tool-call text).
def _get_param_type(self, param_name: str) -> str:
"""Get parameter type based on tool configuration, defaults to string
Args:
param_name: Parameter name
current_text: Current accumulated text
start_pos: Start position in buffer
end_pos: End position in buffer
Returns:
Parameter type
DeltaMessage with content if non-empty
"""
if not self.tools or not self.current_function_name:
return "string"
if start_pos >= end_pos:
return None
for tool in self.tools:
if not hasattr(tool, "type") or not (
hasattr(tool, "function") and hasattr(tool.function, "name")
):
continue
content = current_text[start_pos:end_pos]
# Check if we're between tool calls - skip whitespace
if start_pos > 0:
# Check if text before start_pos ends with </tool_call>
text_before = current_text[:start_pos]
if (
tool.type == "function"
and tool.function.name == self.current_function_name
text_before.rstrip().endswith(self.tool_call_end_token)
and content.strip() == ""
):
if not hasattr(tool.function, "parameters"):
return "string"
params = tool.function.parameters
if isinstance(params, dict) and "properties" in params:
properties = params["properties"]
if param_name in properties and isinstance(
properties[param_name], dict
):
return self.repair_param_type(
str(properties[param_name].get("type", "string"))
)
elif isinstance(params, dict) and param_name in params:
param_config = params[param_name]
if isinstance(param_config, dict):
return self.repair_param_type(
str(param_config.get("type", "string"))
)
break
return "string"
# We just ended a tool call, skip whitespace between tool calls
self._processed_length = end_pos
return None
def repair_param_type(self, param_type: str) -> str:
"""Repair unknown parameter types by treating them as string
Args:
param_type: Parameter type
# Return content if non-empty
if content:
self._processed_length = end_pos
return DeltaMessage(content=content)
Returns:
Repaired parameter type
"""
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"]
or param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
or param_type.startswith("num")
or param_type.startswith("float")
or param_type in ["boolean", "bool", "binary"]
or (
param_type in ["object", "array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
):
return param_type
else:
return "string"
# Mark as processed even if empty
self._processed_length = end_pos
return None
def _flush_remaining_content(self, current_text: str) -> DeltaMessage | None:
"""Flush any remaining unprocessed content as regular content.
def _convert_param_value(self, param_value: str, param_type: str) -> Any:
"""Convert value based on parameter type
Args:
param_value: Parameter value
param_type: Parameter type
current_text: Current accumulated text
Returns:
Converted value
Used when EOS token is encountered to handle incomplete tool calls.
"""
if param_value.lower() == "null":
if not self._has_unprocessed_content(current_text):
return None
param_type = param_type.strip().lower()
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try:
return int(param_value)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not an integer, degenerating to string.",
param_value,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value: float = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not a float, degenerating to string.",
param_value,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
return param_value == "true"
else:
return param_value
remaining = current_text[self._processed_length :]
if remaining:
self._processed_length = len(current_text)
return DeltaMessage(content=remaining)
self._processed_length = len(current_text)
return None
def _format_delta_result(self, deltas: list[DeltaMessage]) -> DeltaMessage | None:
"""Format delta result for return.
Merges all deltas into a single DeltaMessage.
def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str:
"""Convert converted_value based on
whether it's empty and if type is string
Args:
converted_value: Converted value
param_type: Parameter type
deltas: List of delta messages
Returns:
Converted string for streaming output
- None if empty
- Single merged DeltaMessage with all content and tool_calls
"""
# Check if value is empty, but exclude numeric 0
if converted_value is None or converted_value == "":
return ""
if not deltas:
return None
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# String type, remove double quotes
return json.dumps(converted_value, ensure_ascii=False)[1:-1]
else:
# Non-string type, return complete JSON string
if not isinstance(converted_value, str):
return json.dumps(converted_value, ensure_ascii=False)
else:
return converted_value
if len(deltas) == 1:
return deltas[0]
def _reset_xml_parser_after_tool_call(self):
"""
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
# Merge multiple deltas into one
merged_content_parts = []
merged_tool_calls = []
# recreate XML parser
self.parser = ParserCreate()
self.setup_parser()
# Reset current tool_call state
if self.current_call_id:
self.last_completed_call_id = self.current_call_id
self.current_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.text_content_buffer = ""
# Reset preprocessing and deferred parsing state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
for delta in deltas:
if delta.content:
merged_content_parts.append(delta.content)
if delta.tool_calls:
merged_tool_calls.extend(delta.tool_calls)
# Create merged DeltaMessage
merged_content = "".join(merged_content_parts) if merged_content_parts else None
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
# Build kwargs - only include tool_calls if non-empty
kwargs: dict[str, Any] = {"content": merged_content}
if merged_tool_calls:
kwargs["tool_calls"] = merged_tool_calls
# Add missing attributes for compatibility with serving_chat.py
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
return DeltaMessage(**kwargs)
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _process_complete_tool_calls(
self, current_text: str, start_pos: int, end_pos: int
) -> list[DeltaMessage] | None:
"""Process complete tool calls and convert to delta sequence.
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output)
if not result.tool_calls:
return ExtractedToolCallInformation(
tool_calls=[],
tools_called=False,
content=result.content,
)
else:
tool_calls = []
for tool_call in result.tool_calls:
if tool_call.function and tool_call.function.name:
tool_calls.append(
ToolCall(
id=tool_call.id,
type=tool_call.type,
function=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
)
Args:
current_text: Current accumulated text
start_pos: Start position (should be at <tool_call>)
end_pos: End position (after </tool_call>)
# Update tool call tracking arrays for compatibility
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
Returns:
List of DeltaMessage if successful, None otherwise
"""
try:
# Extract text segment containing complete tool call(s)
text_to_parse = current_text[start_pos:end_pos]
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Parse using non-streaming method
result = self.extract_tool_calls_basic(
text_to_parse, self.streaming_request
)
# Update tool call information
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
self.prev_tool_call_arr[tool_index]["arguments"] = (
tool_call.function.arguments
)
# Case 1: Successfully parsed tool calls
if result.tools_called and result.tool_calls:
# Note: Due to _find_first_complete_tool_call_end, we typically
# process only one tool call at a time
# but we can also process multiple tool calls below
deltas = self._build_tool_call_deltas(result.tool_calls, text_to_parse)
self._update_state_after_tool_calls(result.tool_calls, end_pos)
return deltas if deltas else None
# Case 2: Parsing failed - treat as regular content
self._processed_length = end_pos
return [DeltaMessage(content=text_to_parse)]
except Exception as e:
# Exception during parsing - treat as content
logger.debug("Failed to parse tool calls: %s, treating as content", e)
self._processed_length = end_pos
failed_text = current_text[start_pos:end_pos]
return [DeltaMessage(content=failed_text)] if failed_text else None
def _build_tool_call_deltas(
self, tool_calls: list[ToolCall], parsed_text: str
) -> list[DeltaMessage]:
"""Build delta messages from parsed tool calls with interleaved content.
# Update streamed arguments
if tool_call.function.arguments:
self.streamed_args_for_tool[tool_index] = (
tool_call.function.arguments
)
Args:
tool_calls: List of parsed tool calls
parsed_text: Original text that was parsed
return ExtractedToolCallInformation(
tool_calls=tool_calls,
tools_called=len(tool_calls) > 0,
content=result.content,
)
Returns:
List of DeltaMessage with tool calls and content interleaved
"""
# Extract content segments between tool calls
content_segments = self._extract_content_between_tool_calls_list(parsed_text)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if not previous_text:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
# Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended,
# return an empty tool_call for outer streaming output
# to correctly output tool_call field
if not delta_text and delta_token_ids:
open_calls = current_text.count(
self.parser.tool_call_start_token
) - current_text.count(self.parser.tool_call_end_token)
if (
open_calls == 0
and self.parser.tool_call_index > 0
or not self.parser.tool_call_index
and current_text
):
return DeltaMessage(content="")
return None
# Convert all tool calls to DeltaToolCall list
delta_tool_calls = self._convert_tool_calls_to_deltas(
tool_calls, self._tool_call_index
)
# Parse the delta text and get the result
result = self.parser.parse_single_streaming_chunks(delta_text)
# Update tool call tracking arrays based on incremental parsing results
if result and result.tool_calls:
for tool_call in result.tool_calls:
if tool_call.function:
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Merge all content segments into a single string
merged_content = "".join(content_segments)
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Return a single DeltaMessage with all tool calls and content
# Build kwargs - only include non-empty fields
kwargs: dict[str, Any] = {}
if merged_content:
kwargs["content"] = merged_content
if delta_tool_calls:
kwargs["tool_calls"] = delta_tool_calls
# Update tool name if provided
if tool_call.function.name:
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
# Only return DeltaMessage if we have content or tool_calls
if kwargs:
return [DeltaMessage(**kwargs)]
else:
return []
# Update arguments incrementally
if tool_call.function.arguments is not None:
# Concatenate the incremental arguments
# to the existing streamed arguments
self.prev_tool_call_arr[tool_index]["arguments"] += (
tool_call.function.arguments
)
self.streamed_args_for_tool[tool_index] += (
tool_call.function.arguments
)
return result
def _update_state_after_tool_calls(
self, tool_calls: list[ToolCall], end_pos: int
) -> None:
"""Update internal state after processing tool calls.
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Skip the remaining_call calculation in serving_chat
Args:
tool_calls: List of processed tool calls
end_pos: End position in buffer
"""
return False
# Update processed position
self._processed_length = end_pos
# Update tool call index
self._tool_call_index += len(tool_calls)
# Update prev_tool_call_arr for finish_reason
self._update_prev_tool_call_state(tool_calls)
......@@ -514,7 +514,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
seq_threshold_3D=None,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
......
......@@ -957,6 +957,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
......@@ -976,6 +977,12 @@ def _get_kv_cache_groups_uniform_page_size(
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
......@@ -1062,19 +1069,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100,
)
num_groups = cdiv(len(layers), group_size)
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
# for support multi layer mtp, we need to
# make all mtp layers in the same group
if (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.enable_multi_layers_mtp
):
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i : i + group_size])
else:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
......@@ -1259,7 +1275,9 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config(
......
......@@ -38,7 +38,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID,
compute_new_slot_mapping,
......@@ -395,6 +395,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
......@@ -64,3 +66,45 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros(
(1, layer_num), dtype=torch.int32, device=device
)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
CommonAttentionMetadata,
)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
logger = init_logger(__name__)
BLOCK_HIDDEN = 128
BLOCK_TOKENS = 128
class MultiLayerEagleProposer(EagleProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(vllm_config, device, runner)
self.layer_num: int = getattr(
self.speculative_config.draft_model_config.hf_text_config, "n_predict", 0
)
self.num_speculative_tokens: int = (
self.speculative_config.num_speculative_tokens
)
if self.num_speculative_tokens != self.layer_num:
logger.warning_once(
"For multi_layer_eagle, num_speculative_tokens "
"does not match layer_num, adjusting to layer_num"
)
self.num_speculative_tokens = self.layer_num
def adjust_input(
self,
batch_size: int,
target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
MAX_SHIFT = self.layer_num
assert MAX_SHIFT > 0
prev_token_ids = target_token_ids.clone()
prev_positions = target_positions.clone()
prev_hidden_states = target_hidden_states.clone()
slot_mapping = common_attn_metadata.slot_mapping
start_token_indices = common_attn_metadata.query_start_loc[:-1]
end_token_indices = common_attn_metadata.query_start_loc[1:] - 1
pos_for_shift = (
target_positions[0] if target_positions.dim() == 2 else target_positions
)
start_token_pos = pos_for_shift[start_token_indices]
shift = torch.minimum(
end_token_indices - token_indices_to_sample,
start_token_pos,
)
shift = torch.clamp(shift, min=0)
# Metadata updates (matches the original reference implementation).
token_indices_to_sample.add_(shift)
common_attn_metadata.seq_lens.sub_(shift)
cached_lens = multi_layer_eagle_metadata.cached_len
shift = torch.minimum(shift, cached_lens)
_multi_layer_eagle_shift_and_cache(
batch_size=batch_size,
max_shift=MAX_SHIFT,
src_token_ids=target_token_ids,
dst_token_ids=prev_token_ids,
src_positions=target_positions,
dst_positions=prev_positions,
src_hidden_states=target_hidden_states,
dst_hidden_states=prev_hidden_states,
src_slot_mapping=slot_mapping,
dst_slot_mapping=slot_mapping,
start_token_indices=start_token_indices,
end_token_indices=end_token_indices,
token_indices_to_sample=token_indices_to_sample,
shift=shift,
cached_lens=cached_lens,
cached_prev_token_ids=multi_layer_eagle_metadata.cached_token_ids,
cached_prev_positions=multi_layer_eagle_metadata.cached_positions,
cached_prev_hidden_states=multi_layer_eagle_metadata.cached_hidden_states,
cached_slot_mappings=multi_layer_eagle_metadata.cached_slot_mappings,
common_attn_metadata=common_attn_metadata,
)
return prev_token_ids, prev_positions, prev_hidden_states, common_attn_metadata
def initial_inputs_for_forward(
self,
num_tokens: int,
prev_token_ids: torch.Tensor,
prev_positions: torch.Tensor,
prev_hidden_states: torch.Tensor,
next_token_ids: torch.Tensor,
token_indices_to_sample: torch.Tensor,
spec_step_idx: int = 0,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
):
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[: num_tokens - 1] = prev_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[token_indices_to_sample] = next_token_ids
self._set_positions(num_tokens, prev_positions)
self.hidden_states[:num_tokens] = prev_hidden_states[:num_tokens]
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
else:
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
)
def draft_model_forward(
self,
num_tokens: int,
per_layer_attn_metadata: dict[str, Any],
token_indices_to_sample: torch.Tensor,
sampling_metadata: SamplingMetadata,
common_attn_metadata: CommonAttentionMetadata,
spec_step_idx: int = 0,
):
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"hidden_states": self.hidden_states[:num_input_tokens],
"inputs_embeds": inputs_embeds,
"spec_step_idx": spec_step_idx,
}
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
last_hidden_states = self.model(**model_kwargs)
sample_hidden_states = last_hidden_states[token_indices_to_sample]
logits = self.model.compute_logits(
sample_hidden_states, spec_step_idx=spec_step_idx
)
draft_token_ids = logits.argmax(dim=-1)
return draft_token_ids, last_hidden_states
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens] or [3, num_tokens] when M-RoPE is enabled
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> torch.Tensor:
assert self.method == "mtp"
assert self.runner is not None
assert multi_layer_eagle_metadata is not None
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
if token_indices_to_sample is None:
token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1
prev_token_ids, prev_positions, prev_hidden_states, common_attn_metadata = (
self.adjust_input(
batch_size=batch_size,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
)
# Build per-layer attention metadata using draft attention groups
attn_metadata = None
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self._draft_attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# Generate the remaining draft tokens.
draft_token_ids_list: list[torch.Tensor] = []
for token_index in range(self.num_speculative_tokens):
if token_index != 0:
prev_token_ids = self.input_ids[:num_tokens].clone()
next_token_ids = draft_token_ids_list[-1].int()
self.initial_inputs_for_forward(
num_tokens=num_tokens,
prev_token_ids=prev_token_ids,
prev_positions=prev_positions,
prev_hidden_states=prev_hidden_states,
next_token_ids=next_token_ids,
token_indices_to_sample=token_indices_to_sample,
spec_step_idx=token_index,
mm_embed_inputs=mm_embed_inputs,
)
draft_token_ids, prev_hidden_states = self.draft_model_forward(
num_tokens=num_tokens,
per_layer_attn_metadata=per_layer_attn_metadata,
token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
spec_step_idx=token_index,
)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
return draft_token_ids.view(-1, 1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
raise Exception(
"speculative_config.disable_padded_drafter_batch"
" is not supported now for MultiLayerEagleProposer."
)
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self._draft_attn_layer_names
and slot_mappings is not None
and next(iter(self._draft_attn_layer_names)) in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
adjust_input_kwargs = {
"batch_size": 1,
"target_token_ids": self.input_ids[:num_input_tokens],
"target_positions": self._get_positions(num_input_tokens),
"target_hidden_states": self.hidden_states[:num_input_tokens],
"token_indices_to_sample": torch.tensor(
[num_input_tokens - 1], dtype=torch.int32, device=self.device
),
"common_attn_metadata": CommonAttentionMetadata(
query_start_loc=torch.tensor(
[0, num_input_tokens], dtype=torch.int32, device=self.device
),
query_start_loc_cpu=torch.tensor(
[0, num_input_tokens], dtype=torch.int32, device="cpu"
),
seq_lens=torch.tensor(
[num_input_tokens], dtype=torch.int32, device=self.device
),
num_reqs=1,
num_actual_tokens=num_input_tokens,
max_query_len=num_input_tokens,
max_seq_len=self.max_model_len,
block_table_tensor=torch.tensor(
[], dtype=torch.int32, device=self.device
),
slot_mapping=self.arange[:num_input_tokens],
logits_indices_padded=None,
num_logits_indices=None,
causal=True,
encoder_seq_lens=None,
),
"multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
layer_num=self.layer_num,
hidden_size=self.hidden_size,
device=self.device,
),
}
# NOTE ensure the jit kernel in _adjust_input can be compiled
self.adjust_input(**adjust_input_kwargs)
for fwd_idx in range(self.layer_num):
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"hidden_states": self.hidden_states[:num_input_tokens],
"inputs_embeds": inputs_embeds,
"spec_step_idx": fwd_idx,
}
self.model(**model_kwargs)
def _multi_layer_eagle_shift_and_cache(
*,
batch_size: int,
max_shift: int,
src_token_ids: torch.Tensor,
dst_token_ids: torch.Tensor,
src_positions: torch.Tensor,
dst_positions: torch.Tensor,
src_hidden_states: torch.Tensor,
dst_hidden_states: torch.Tensor,
src_slot_mapping: torch.Tensor,
dst_slot_mapping: torch.Tensor,
start_token_indices: torch.Tensor,
end_token_indices: torch.Tensor,
token_indices_to_sample: torch.Tensor,
shift: torch.Tensor,
cached_lens: torch.Tensor,
cached_prev_token_ids: torch.Tensor,
cached_prev_positions: torch.Tensor,
cached_prev_hidden_states: torch.Tensor,
cached_slot_mappings: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
):
if batch_size == 0:
return
assert max_shift > 0
assert cached_prev_positions.is_contiguous()
assert cached_prev_token_ids.is_contiguous()
assert cached_prev_hidden_states.is_contiguous()
assert cached_slot_mappings.is_contiguous()
assert src_hidden_states.is_contiguous()
assert dst_hidden_states.is_contiguous()
# If src/dst are the same tensor, shifting is unsafe without a separate src.
if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
src_slot_mapping = src_slot_mapping.clone()
# Cache extraction for the next call.
store_start = torch.maximum(
start_token_indices,
(token_indices_to_sample + 1 - max_shift),
)
store_lens = torch.clamp(
token_indices_to_sample - store_start + 1,
min=0,
max=max_shift,
)
# Avoid device sync: query length == (end - start + 1) == diff of
# query_start_loc (CPU copy).
max_window_len = int(
(
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
.max()
.item()
)
num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_token_ids,
dst_token_ids,
cached_prev_token_ids,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_slot_mapping,
dst_slot_mapping,
cached_slot_mappings,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_positions,
dst_positions,
cached_prev_positions,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
hidden_size = int(dst_hidden_states.shape[1])
# Hidden blocking avoids extremely large Triton tiles (and huge cubins)
# when hidden_size is large.
num_hidden_blocks = max(1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN)
_shift_and_gather_hidden_kernel[(batch_size, num_blocks, num_hidden_blocks)](
src_hidden_states,
dst_hidden_states,
cached_prev_hidden_states,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
HIDDEN_SIZE=hidden_size,
BLOCK_TOKENS=BLOCK_TOKENS,
BLOCK_HIDDEN=BLOCK_HIDDEN,
num_warps=4,
)
cached_lens.copy_(store_lens)
return
@triton.jit
def _shift_and_gather_cache_1d_kernel(
src_ptr,
dst_ptr,
cached_ptr,
start_ptr,
end_ptr,
shift_ptr,
cached_len_ptr,
store_start_ptr,
store_len_ptr,
MAX_SHIFT: tl.constexpr,
PADDED_SHIFT: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
):
# Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
# slot mappings, ...).
#
# We operate on a packed batch where each sequence (request) occupies a
# contiguous window [start, end] (inclusive) in a flattened tensor.
# For the next speculative step, we build a right-shifted version of each
# window. The shift amount can differ per sequence.
#
# For a single sequence (0-based index i within its window):
# - Prefix (i < shift):
# dst[start + i] = cached[cached_len - shift + i]
# - Body (i >= shift):
# dst[start + i] = src[start + i - shift]
#
# The vacated prefix is filled from a small per-sequence cache (up to
# MAX_SHIFT elements) that stores values from previous speculative steps.
#
# Example:
# cached_tail = [a3, a4]
# src_window = [b0, b1, b2, b3, b4]
# shift = 2
# -> dst_window = [a3, a4, b0, b1, b2]
#
# After dst is produced, we refresh cached_ptr[seq, :] with a suffix of dst
# (specified by store_start / store_len) so the next call can populate its
# prefix from cache.
pid_seq = tl.program_id(0)
pid_blk = tl.program_id(1)
start = tl.load(start_ptr + pid_seq).to(tl.int32)
end = tl.load(end_ptr + pid_seq).to(tl.int32)
shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
assert cached_len >= shift
# get dst indices
base = pid_blk * BLOCK_TOKENS
k = tl.arange(0, BLOCK_TOKENS)
offs = base + k
dst_idx = start + offs
# get dst mask
window_len = end - start + 1
mask = offs < window_len
# load from cached
base_cached = cached_ptr + pid_seq * MAX_SHIFT
cached_idx = cached_len - shift + offs
cached_mask = offs < shift
val_cached = tl.load(base_cached + cached_idx, mask=mask & cached_mask, other=0)
# load from src
src_idx = start + offs - shift
val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)
# store to dst
val = tl.where(cached_mask, val_cached, val_src)
tl.store(dst_ptr + dst_idx, val, mask=mask)
# Store into the per-sequence cache.
#
# Cache layout: [batch_size, MAX_SHIFT] (flattened). We always write the
# full MAX_SHIFT region (zero-padded when store_len < MAX_SHIFT) to keep the
# cache contiguous.
store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
m = tl.arange(0, PADDED_SHIFT)
store_mask = m < MAX_SHIFT
dst_idx = store_start + m
val = tl.load(dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0)
tl.store(base_cached + m, val, mask=store_mask)
@triton.jit
def _shift_and_gather_hidden_kernel(
src_ptr,
dst_ptr,
cached_ptr,
start_ptr,
end_ptr,
shift_ptr,
cached_len_ptr,
store_start_ptr,
store_len_ptr,
MAX_SHIFT: tl.constexpr,
PADDED_SHIFT: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
BLOCK_HIDDEN: tl.constexpr,
):
# Per-sequence "shift + gather" for hidden states.
#
# This kernel implements the same logical transformation as
# _shift_and_gather_cache_1d_kernel, but operates on hidden states with
# shape [num_tokens, hidden_size].
#
# Layout:
# - src_ptr / dst_ptr: packed hidden states [num_tokens, hidden_size]
# - cached_ptr: per-sequence cache [batch_size, MAX_SHIFT, hidden_size]
#
# For each sequence window [start, end] (inclusive) and its shift value, for
# 0-based index i within the window:
# - Prefix (i < shift):
# dst[start + i, :] = cached[seq, cached_len - shift + i, :]
# - Body (i >= shift):
# dst[start + i, :] = src[start + i - shift, :]
#
# We tile over tokens (BLOCK_TOKENS) and hidden dim (BLOCK_HIDDEN) to avoid
# extremely large Triton tiles when hidden_size is large. As in the 1D
# kernel, we refresh cached_ptr[seq, :, :] with a suffix of dst so the next
# call can populate its prefix from cache.
pid_seq = tl.program_id(0)
pid_blk = tl.program_id(1)
pid_hid = tl.program_id(2)
start = tl.load(start_ptr + pid_seq).to(tl.int32)
end = tl.load(end_ptr + pid_seq).to(tl.int32)
shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
assert cached_len >= shift
# get dst indices
base = pid_blk * BLOCK_TOKENS
k = tl.arange(0, BLOCK_TOKENS)
tok_offs = base + k
dst_tok = start + tok_offs
n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
# get dst mask
window_len = end - start + 1
tok_mask = tok_offs < window_len
n_mask = n < HIDDEN_SIZE
mask = tok_mask[:, None] & n_mask[None, :]
# load from cached
base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
cached_tok = cached_len - shift + tok_offs
cached_ptrs = base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
cached_mask = tok_offs < shift
val_cached = tl.load(cached_ptrs, mask=mask & cached_mask[:, None], other=0)
# load from src
src_tok = start + tok_offs - shift
src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)
# store to dst
val = tl.where(cached_mask[:, None], val_cached, val_src)
tl.store(dst_ptrs, val, mask=mask)
# store to cached
store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
m = tl.arange(0, PADDED_SHIFT)
m_mask = (m < MAX_SHIFT) & (m < store_len)
store_tok = store_start + m
dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
store_ptrs = base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
mask = m_mask[:, None] & n_mask[None, :]
val = tl.load(dst_ptrs, mask=mask, other=0)
tl.store(store_ptrs, val, mask=mask)
......@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
......@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
......@@ -223,6 +232,46 @@ class InputBatch:
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
......@@ -437,6 +486,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
......@@ -632,6 +688,20 @@ class InputBatch:
self.num_accepted_tokens_cpu[i1],
)
if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[[i2, i1], ...]
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
......@@ -769,6 +839,23 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[
last_req_index
]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
......
......@@ -164,7 +164,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer
from vllm.v1.spec_decode.ngram_proposer_gpu import (
NgramProposerGPU,
copy_num_valid_draft_tokens,
......@@ -374,6 +375,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
......@@ -500,6 +502,11 @@ class GPUModelRunner(
self.late_interaction_runner = LateInteractionRunner()
self.use_aux_hidden_state_outputs = False
# multi layer eagle
self.enable_multi_layer_eagle = False
self.multi_layer_eagle_num = 0
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
......@@ -544,7 +551,17 @@ class GPUModelRunner(
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if (
self.speculative_config.enable_multi_layers_mtp
and self.speculative_config.method == "mtp"
):
self.enable_multi_layer_eagle = True
self.drafter = MultiLayerEagleProposer(
self.vllm_config, self.device, self
)
self.multi_layer_eagle_num = self.drafter.layer_num
else:
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = (
self.drafter.eagle3_use_aux_hidden_state
......@@ -623,6 +640,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
)
# Separate cuda stream for overlapping transfer of sampled token ids from
......@@ -1143,6 +1164,9 @@ class GPUModelRunner(
if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state)
if self.enable_multi_layer_eagle:
self._init_multi_layer_eagle_cache(req_state)
reqs_to_add.append(req_state)
# Track new requests for ngram_gpu full tensor copy
if is_ngram_gpu:
......@@ -1442,6 +1466,24 @@ class GPUModelRunner(
req_state.mm_features,
)
def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState):
req_state.cached_len = torch.zeros(1, dtype=torch.int64, device=self.device)
req_state.cached_hidden_states = torch.zeros(
self.multi_layer_eagle_num,
self.model_config.get_hidden_size(),
dtype=self.dtype,
device=self.device,
)
req_state.cached_token_ids = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int32, device=self.device
)
req_state.cached_positions = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int64, device=self.device
)
req_state.cached_slot_mappings = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int64, device=self.device
)
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
......@@ -1672,10 +1714,11 @@ class GPUModelRunner(
) -> tuple[
torch.Tensor,
SpecDecodeMetadata | None,
MultiLayerEagleMetadata | None,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
logits_indices, spec_decode_metadata, multi_layer_eagle_metadata,
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
......@@ -1879,9 +1922,21 @@ class GPUModelRunner(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
if self.enable_multi_layer_eagle:
multi_layer_eagle_metadata = MultiLayerEagleMetadata(
cached_len=self.input_batch.cached_len[:num_reqs],
cached_token_ids=self.input_batch.cached_token_ids[:num_reqs],
cached_hidden_states=self.input_batch.cached_hidden_states[:num_reqs],
cached_slot_mappings=self.input_batch.cached_slot_mappings[:num_reqs],
cached_positions=self.input_batch.cached_positions[:num_reqs],
)
else:
multi_layer_eagle_metadata = None
return (
logits_indices,
spec_decode_metadata,
multi_layer_eagle_metadata,
)
def _build_attention_metadata(
......@@ -3634,9 +3689,11 @@ class GPUModelRunner(
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
logits_indices, spec_decode_metadata = self._prepare_inputs(
scheduler_output,
num_scheduled_tokens_np,
logits_indices, spec_decode_metadata, multi_layer_eagle_metadata = (
self._prepare_inputs(
scheduler_output,
num_scheduled_tokens_np,
)
)
cascade_attn_prefix_lens = None
......@@ -3867,6 +3924,7 @@ class GPUModelRunner(
scheduler_output,
logits,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
......@@ -3905,6 +3963,7 @@ class GPUModelRunner(
scheduler_output,
logits,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
......@@ -3953,6 +4012,7 @@ class GPUModelRunner(
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
slot_mappings,
)
......@@ -4242,6 +4302,7 @@ class GPUModelRunner(
sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor:
......@@ -4466,6 +4527,7 @@ class GPUModelRunner(
token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings,
......@@ -6216,6 +6278,10 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
)
assert self._init_block_sizes == block_sizes, (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
from ._version import __version__, __version_tuple__
__version__ = "0.18.1"
__version_tuple__ = (0, 18, 1)
__hcu_version__ = f'0.18.1+das.ori.dtk2604'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2)
warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
......@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number."""
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment