Commit 84ac1d27 authored by zhangning3's avatar zhangning3
Browse files

add step3p5 mtp3

parent b27f1671
...@@ -57,6 +57,11 @@ def parse_args(): ...@@ -57,6 +57,11 @@ def parse_args():
parser.add_argument("--tp", type=int, default=1) parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enforce-eager", action="store_true")
parser.add_argument("--enable-chunked-prefill", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true")
parser.add_argument(
"--enable-multi-layers-mtp",
action="store_true",
help="Enable multi-layer MTP (only effective when --method=mtp).",
)
parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--max-model-len", type=int, default=16384)
parser.add_argument("--temp", type=float, default=0) parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-p", type=float, default=1.0)
...@@ -66,12 +71,14 @@ def parse_args(): ...@@ -66,12 +71,14 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None) parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--tokenizer-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true") parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None) parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true") parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="") parser.add_argument("--allowed-local-media-path", type=str, default="")
parser.add_argument("--trust-remote-code", action="store_true")
return parser.parse_args() return parser.parse_args()
...@@ -85,7 +92,11 @@ def main(args): ...@@ -85,7 +92,11 @@ def main(args):
"please specify model_dir to give a mm based model" "please specify model_dir to give a mm based model"
) )
model_dir = "meta-llama/Llama-3.1-8B-Instruct" model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir) tokenizer_dir = args.tokenizer_dir
if tokenizer_dir is None:
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
if args.custom_mm_prompts: if args.custom_mm_prompts:
prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts) prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
...@@ -141,6 +152,8 @@ def main(args): ...@@ -141,6 +152,8 @@ def main(args):
"method": "mtp", "method": "mtp",
"num_speculative_tokens": args.num_spec_tokens, "num_speculative_tokens": args.num_spec_tokens,
} }
if args.enable_multi_layers_mtp:
speculative_config["enable_multi_layers_mtp"] = True
else: else:
raise ValueError(f"unknown method: {args.method}") raise ValueError(f"unknown method: {args.method}")
......
#!/usr/bin/env bash
set -u
DEFAULT_BATCH_SIZES=(1 8 16 32 64 128)
MODEL_PATH="/module/step3.5-fp8/"
SERVED_MODEL_NAME="/module/step3.5-fp8/"
DATASET_NAME="random"
DEFAULT_OUTPUT_LEN_DECODE=4096
DEFAULT_OUTPUT_LEN_PREFILL=1
DEFAULT_ROLE="decode"
READY_CHECK_TIMEOUT=3
RESULT_DIR="benchmark_result"
print_usage() {
cat <<'USAGE'
Usage:
./scripts/step3p5_benchmark_test.sh
./scripts/step3p5_benchmark_test.sh 1,8,16,32
./scripts/step3p5_benchmark_test.sh --role prefill
./scripts/step3p5_benchmark_test.sh --role decode
./scripts/step3p5_benchmark_test.sh --role both
./scripts/step3p5_benchmark_test.sh --role prefill --output-len 1
Description:
- No argument: use default batch sizes, role=decode, output-len=4096
- Optional positional argument: batch size list (comma or space separated)
- Optional flag: --role <prefill|decode|both>
- Optional flag: --output-len <N> (must be positive integer)
- role=both 时串行执行 prefill 再 decode
- Result files are saved under:
<result_dir>/prefill (when role=prefill)
<result_dir>/decode (when role=decode)
USAGE
}
parse_batch_sizes() {
local raw_input="${1:-}"
if [[ -z "$raw_input" ]]; then
BATCH_SIZES=("${DEFAULT_BATCH_SIZES[@]}")
return
fi
raw_input="${raw_input//,/ }"
read -r -a BATCH_SIZES <<< "$raw_input"
if (( ${#BATCH_SIZES[@]} == 0 )); then
echo "[ERROR] batch size 列表为空。"
print_usage
exit 1
fi
local batch_size
for batch_size in "${BATCH_SIZES[@]}"; do
if ! [[ "$batch_size" =~ ^[1-9][0-9]*$ ]]; then
echo "[ERROR] 非法 batch size: $batch_size(必须是正整数)"
exit 1
fi
done
}
parse_role() {
local role_input="${1:-$DEFAULT_ROLE}"
if [[ "$role_input" != "prefill" && "$role_input" != "decode" && "$role_input" != "both" ]]; then
echo "[ERROR] 非法 role: $role_input(必须是 prefill、decode 或 both)"
print_usage
exit 1
fi
ROLE="$role_input"
}
parse_output_len() {
local output_len_input="$1"
if ! [[ "$output_len_input" =~ ^[1-9][0-9]*$ ]]; then
echo "[ERROR] 非法 output-len: $output_len_input(必须是正整数)"
print_usage
exit 1
fi
RANDOM_OUTPUT_LEN="$output_len_input"
}
calc_num_prompts() {
local batch_size="$1"
local num_prompts=$((batch_size + batch_size / 2))
if (( num_prompts < 16 )); then
num_prompts=16
fi
if (( num_prompts > 384 )); then
num_prompts=384
fi
if (( num_prompts < batch_size )); then
num_prompts=$batch_size
fi
echo "$num_prompts"
}
main() {
local batch_arg=""
local role_arg="$DEFAULT_ROLE"
local output_len_arg=""
while (( $# > 0 )); do
case "$1" in
-h|--help)
print_usage
exit 0
;;
--role|-r)
if [[ -z "${2:-}" ]]; then
echo "[ERROR] --role 缺少参数。"
print_usage
exit 1
fi
role_arg="$2"
shift 2
;;
--output-len|-o)
if [[ -z "${2:-}" ]]; then
echo "[ERROR] --output-len 缺少参数。"
print_usage
exit 1
fi
output_len_arg="$2"
shift 2
;;
--*)
echo "[ERROR] 未知参数: $1"
print_usage
exit 1
;;
*)
if [[ -n "$batch_arg" ]]; then
echo "[ERROR] 仅支持一个 batch size 列表参数。"
print_usage
exit 1
fi
batch_arg="$1"
shift
;;
esac
done
parse_batch_sizes "$batch_arg"
parse_role "$role_arg"
if [[ "$ROLE" == "both" ]]; then
local -a prefill_cmd=("$0")
local -a decode_cmd=("$0")
if [[ -n "$batch_arg" ]]; then
prefill_cmd+=("$batch_arg")
decode_cmd+=("$batch_arg")
fi
prefill_cmd+=("--role" "prefill")
decode_cmd+=("--role" "decode")
if [[ -n "$output_len_arg" ]]; then
decode_cmd+=("--output-len" "$output_len_arg")
fi
echo "[INFO] role=both: 将串行执行 prefill 和 decode"
echo "[INFO] step1: ${prefill_cmd[*]}"
"${prefill_cmd[@]}"
echo "[INFO] step2: ${decode_cmd[*]}"
"${decode_cmd[@]}"
echo "[INFO] role=both 执行完成。"
return 0
fi
if [[ "$ROLE" == "prefill" ]]; then
if [[ -n "$output_len_arg" && "$output_len_arg" != "$DEFAULT_OUTPUT_LEN_PREFILL" ]]; then
echo "[WARN] role=prefill 时 output-len 必须为 1,已自动覆盖为 1。"
fi
output_len_arg="$DEFAULT_OUTPUT_LEN_PREFILL"
elif [[ -z "$output_len_arg" ]]; then
output_len_arg="$DEFAULT_OUTPUT_LEN_DECODE"
fi
parse_output_len "$output_len_arg"
RESULT_SUBDIR="$RESULT_DIR/$ROLE"
mkdir -p "$RESULT_SUBDIR"
echo "[INFO] 将执行 ${#BATCH_SIZES[@]} 组 benchmark"
echo "[INFO] role: $ROLE"
echo "[INFO] random-output-len: $RANDOM_OUTPUT_LEN"
echo "[INFO] batch size 列表: ${BATCH_SIZES[*]}"
echo "[INFO] result_dir: $RESULT_SUBDIR"
local batch_size
local num_prompts
local failed_count=0
for batch_size in "${BATCH_SIZES[@]}"; do
num_prompts="$(calc_num_prompts "$batch_size")"
echo ""
echo "[INFO] 开始测试: role=$ROLE, batch_size=$batch_size, max_concurrency=$batch_size, num_prompts=$num_prompts, random_output_len=$RANDOM_OUTPUT_LEN"
if ! vllm bench serve \
--backend vllm \
--model "$MODEL_PATH" \
--served-model-name "$SERVED_MODEL_NAME" \
--dataset-name "$DATASET_NAME" \
--random-input-len 65536 \
--random-output-len "$RANDOM_OUTPUT_LEN" \
--num-prompts "$num_prompts" \
--temperature 0 \
--max-concurrency "$batch_size" \
--ready-check-timeout "$READY_CHECK_TIMEOUT" \
--result-dir "$RESULT_SUBDIR" \
--port 8018 \
--save-result; then
echo "[WARN] batch_size=$batch_size 执行失败,继续下一个。"
failed_count=$((failed_count + 1))
else
echo "[INFO] batch_size=$batch_size 执行完成。"
fi
done
echo ""
if (( failed_count > 0 )); then
echo "[WARN] 全部执行结束,但有 $failed_count 组失败。"
exit 1
fi
echo "[INFO] 全部 benchmark 执行完成。"
}
main "$@"
...@@ -386,6 +386,57 @@ TX ...@@ -386,6 +386,57 @@ TX
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
def test_extract_tool_calls_missing_function_name(step3p5_tool_parser, sample_tools):
"""Tool call without function name should fallback to content."""
model_output = (
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls(
model_output, request=request
)
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def test_extract_tool_calls_empty_function_name(step3p5_tool_parser, sample_tools):
"""Tool call with empty function name should fallback to content."""
model_output = (
"<tool_call><function=><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls(
model_output, request=request
)
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def test_extract_tool_calls_empty_function_name_single_quotes(
step3p5_tool_parser, sample_tools
):
"""Tool call with empty function name (single quotes) should fallback."""
model_output = (
"<tool_call><function=''><parameter=pattern>*.py</parameter></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls(
model_output, request=request
)
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def test_extract_tool_calls_type_conversion(step3p5_tool_parser): def test_extract_tool_calls_type_conversion(step3p5_tool_parser):
"""Test parameter type conversion based on tool schema""" """Test parameter type conversion based on tool schema"""
tools = [ tools = [
...@@ -623,7 +674,7 @@ def test_extract_tool_calls_streaming( ...@@ -623,7 +674,7 @@ def test_extract_tool_calls_streaming(
expected_tool_calls, expected_tool_calls,
expected_content, expected_content,
): ):
"""Test incremental streaming behavior including typed parameters""" """Test streaming returns complete tool calls when parsed."""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = "" other_content = ""
...@@ -647,11 +698,10 @@ def test_extract_tool_calls_streaming( ...@@ -647,11 +698,10 @@ def test_extract_tool_calls_streaming(
tool_states[idx] = { tool_states[idx] = {
"id": None, "id": None,
"name": None, "name": None,
"arguments": "", "arguments": None,
"type": None, "type": None,
} }
# First chunk should have id, name, and type
if tool_call.id: if tool_call.id:
tool_states[idx]["id"] = tool_call.id tool_states[idx]["id"] = tool_call.id
...@@ -666,8 +716,15 @@ def test_extract_tool_calls_streaming( ...@@ -666,8 +716,15 @@ def test_extract_tool_calls_streaming(
tool_states[idx]["name"] = tool_call.function.name tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None: if tool_call.function.arguments is not None:
# Accumulate arguments incrementally # Arguments should be complete JSON when emitted.
tool_states[idx]["arguments"] += tool_call.function.arguments json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Verify final content # Verify final content
assert other_content == (expected_content or "") # Handle None case assert other_content == (expected_content or "") # Handle None case
...@@ -682,7 +739,7 @@ def test_extract_tool_calls_streaming( ...@@ -682,7 +739,7 @@ def test_extract_tool_calls_streaming(
assert state["type"] == "function" assert state["type"] == "function"
assert state["name"] == expected_tool.function.name assert state["name"] == expected_tool.function.name
# Parse accumulated arguments # Parse arguments
arguments_str = state["arguments"] arguments_str = state["arguments"]
assert arguments_str is not None assert arguments_str is not None
actual_args = json.loads(arguments_str) actual_args = json.loads(arguments_str)
...@@ -770,7 +827,7 @@ fahrenheit ...@@ -770,7 +827,7 @@ fahrenheit
tool_states[idx] = { tool_states[idx] = {
"id": None, "id": None,
"name": None, "name": None,
"arguments": "", "arguments": None,
"type": None, "type": None,
} }
...@@ -786,7 +843,14 @@ fahrenheit ...@@ -786,7 +843,14 @@ fahrenheit
tool_states[idx]["name"] = tool_call.function.name tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None: if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Verify content was streamed # Verify content was streamed
assert "Let me check the weather for you:" in other_content assert "Let me check the weather for you:" in other_content
...@@ -806,62 +870,69 @@ fahrenheit ...@@ -806,62 +870,69 @@ fahrenheit
assert args["unit"] == "fahrenheit" assert args["unit"] == "fahrenheit"
def test_extract_tool_calls_streaming_incremental( def test_extract_tool_calls_streaming_missing_function_name(
step3p5_tool_parser, step3p5_tokenizer, sample_tools step3p5_tool_parser, step3p5_tokenizer, sample_tools
): ):
"""Test that streaming is truly incremental""" """Streaming: missing function name should be treated as content."""
model_output = """I'll check the weather.<tool_call> model_output = (
<function=get_current_weather> "<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
<parameter=city> )
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
chunks = [] other_content = ""
for delta_message in stream_delta_message_generator( tool_calls = []
step3p5_tool_parser, step3p5_tokenizer, model_output, request
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser,
step3p5_tokenizer,
[
"<tool_call><parameter=pattern>",
"*.py</parameter>",
"</function></tool_call>",
],
request,
):
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
tool_calls.extend(delta_message.tool_calls)
assert other_content == model_output
assert tool_calls == []
def test_extract_tool_calls_streaming_empty_function_name(
step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
"""Streaming: empty function name should be treated as content."""
model_output = (
"<tool_call><function=''><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ""
tool_calls = []
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser,
step3p5_tokenizer,
[
"<tool_call><function=",
"''><parameter=pattern>*.py</parameter>",
"</function></tool_call>",
],
request,
): ):
chunks.append(delta_message) if delta_message.content:
other_content += delta_message.content
# Should have multiple chunks if delta_message.tool_calls:
assert len(chunks) > 3 tool_calls.extend(delta_message.tool_calls)
# First chunk(s) should be content assert other_content == model_output
assert chunks[0].content is not None assert tool_calls == []
assert chunks[0].tool_calls is None or chunks[0].tool_calls == []
# Should have a chunk with tool header (id, name, type)
header_found = False
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert chunk.tool_calls[0].function.name == "get_current_weather"
assert chunk.tool_calls[0].type == "function"
# Empty initially
assert chunk.tool_calls[0].function.arguments == ""
break
assert header_found
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].function.arguments:
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
assert len(arg_chunks) > 1
# Concatenated arguments should form valid JSON
full_args = "".join(arg_chunks)
parsed_args = json.loads(full_args)
assert parsed_args["city"] == "Dallas"
assert parsed_args["state"] == "TX"
def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser): def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser):
...@@ -951,7 +1022,7 @@ rectangle ...@@ -951,7 +1022,7 @@ rectangle
tool_states[idx] = { tool_states[idx] = {
"id": None, "id": None,
"name": None, "name": None,
"arguments": "", "arguments": None,
"type": None, "type": None,
} }
...@@ -967,7 +1038,14 @@ rectangle ...@@ -967,7 +1038,14 @@ rectangle
tool_states[idx]["name"] = tool_call.function.name tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None: if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls # Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls" assert len(tool_states) == 2, "Should have exactly two complete tool calls"
...@@ -1164,7 +1242,7 @@ rectangle ...@@ -1164,7 +1242,7 @@ rectangle
tool_states[idx] = { tool_states[idx] = {
"id": None, "id": None,
"name": None, "name": None,
"arguments": "", "arguments": None,
"type": None, "type": None,
} }
if tool_call.id: if tool_call.id:
...@@ -1175,7 +1253,14 @@ rectangle ...@@ -1175,7 +1253,14 @@ rectangle
if tool_call.function.name: if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None: if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls # Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls" assert len(tool_states) == 2, "Should have exactly two complete tool calls"
...@@ -1266,7 +1351,7 @@ rectangle ...@@ -1266,7 +1351,7 @@ rectangle
tool_states[idx] = { tool_states[idx] = {
"id": None, "id": None,
"name": None, "name": None,
"arguments": "", "arguments": None,
"type": None, "type": None,
} }
...@@ -1282,7 +1367,14 @@ rectangle ...@@ -1282,7 +1367,14 @@ rectangle
tool_states[idx]["name"] = tool_call.function.name tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None: if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls # Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls" assert len(tool_states) == 2, "Should have exactly two complete tool calls"
...@@ -1344,20 +1436,26 @@ rectangle""", ...@@ -1344,20 +1436,26 @@ rectangle""",
for delta_message in stream_delta_message_generator_from_chunks( for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser, step3p5_tokenizer, delta_text_chunks, request step3p5_tool_parser, step3p5_tokenizer, delta_text_chunks, request
): ):
print(delta_message)
if delta_message.tool_calls: if delta_message.tool_calls:
for tool_call in delta_message.tool_calls: for tool_call in delta_message.tool_calls:
idx = tool_call.index idx = tool_call.index
if idx not in tool_states: if idx not in tool_states:
tool_states[idx] = { tool_states[idx] = {
"name": None, "name": None,
"arguments": "", "arguments": None,
} }
if tool_call.function: if tool_call.function:
if tool_call.function.name: if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None: if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
assert len(tool_states) == 2 assert len(tool_states) == 2
assert all(state["name"] for state in tool_states.values()) assert all(state["name"] for state in tool_states.values())
...@@ -1368,7 +1466,7 @@ rectangle""", ...@@ -1368,7 +1466,7 @@ rectangle""",
def test_extract_tool_calls_non_streaming_multiple_tool_calls_no_content_between( def test_extract_tool_calls_non_streaming_multiple_tool_calls_no_content_between(
step3p5_tool_parser, sample_tools step3p5_tool_parser, sample_tools
): ):
"""Test non-streaming extraction with tool calls and no content between them. """Test non-streaming extraction with multiple tool calls.
Scenario: Model outputs "hello" + tool call + tool call. Scenario: Model outputs "hello" + tool call + tool call.
Expected: "hello" as content, first tool call parsed (index=0), Expected: "hello" as content, first tool call parsed (index=0),
......
This diff is collapsed.
...@@ -80,6 +80,10 @@ class SpeculativeConfig: ...@@ -80,6 +80,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered.""" `prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1) draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
...@@ -493,7 +497,10 @@ class SpeculativeConfig: ...@@ -493,7 +497,10 @@ class SpeculativeConfig:
MTPModelTypes MTPModelTypes
): ):
self.method = "mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning( logger.warning(
"Enabling num_speculative_tokens > 1 will run " "Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer" "multiple times of forward on same MTP layer"
......
...@@ -166,7 +166,72 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -166,7 +166,72 @@ class AnthropicServingMessages(OpenAIServingChat):
if isinstance(msg.content, str): if isinstance(msg.content, str):
openai_msg["content"] = msg.content openai_msg["content"] = msg.content
else: else:
cls._convert_message_content(msg, openai_msg, openai_messages) # Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.id or "",
"content": str(block.content)
if block.content
else "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls and not reasoning_parts:
continue
openai_messages.append(openai_msg) openai_messages.append(openai_msg)
...@@ -522,49 +587,75 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -522,49 +587,75 @@ class AnthropicServingMessages(OpenAIServingChat):
first_item = True first_item = True
finish_reason = None finish_reason = None
state = _ActiveBlockState() content_block_index = 0
active_block_type: str | None = None
active_block_index: int | None = None
active_block_signature: str | None = None
signature_emitted = False
active_tool_use_id: str | None = None
# Map from tool call index to tool_use_id # Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {} tool_index_to_id: dict[int, str] = {}
def stop_active_block(): def stop_active_block():
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
events: list[str] = [] events: list[str] = []
if state.block_type is None: if active_block_type is None:
return events return events
if ( if (
state.block_type == "thinking" active_block_type == "thinking"
and state.block_signature is not None and active_block_signature is not None
and not state.signature_emitted and not signature_emitted
): ):
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=state.block_index, index=active_block_index,
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
type="signature_delta", type="signature_delta",
signature=state.block_signature, signature=active_block_signature,
), ),
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta")) events.append(wrap_data_with_event(data, "content_block_delta"))
state.signature_emitted = True signature_emitted = True
stop_chunk = AnthropicStreamEvent( stop_chunk = AnthropicStreamEvent(
index=state.block_index, index=active_block_index,
type="content_block_stop", type="content_block_stop",
) )
data = stop_chunk.model_dump_json(exclude_unset=True) data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop")) events.append(wrap_data_with_event(data, "content_block_stop"))
state.reset() active_block_type = None
state.content_block_index += 1 active_block_index = None
active_block_signature = None
signature_emitted = False
active_tool_use_id = None
content_block_index += 1
return events return events
def start_block(block: AnthropicContentBlock): def start_block(block: AnthropicContentBlock):
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=state.content_block_index, index=content_block_index,
type="content_block_start", type="content_block_start",
content_block=block, content_block=block,
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start") event = wrap_data_with_event(data, "content_block_start")
state.start(block) active_block_type = block.type
active_block_index = content_block_index
if block.type == "thinking":
active_block_signature = uuid.uuid4().hex
signature_emitted = False
active_tool_use_id = None
elif block.type == "tool_use":
active_block_signature = None
signature_emitted = True
active_tool_use_id = block.id
else:
active_block_signature = None
signature_emitted = True
active_tool_use_id = None
return event return event
async for item in generator: async for item in generator:
...@@ -638,7 +729,7 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -638,7 +729,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if reasoning_delta == "": if reasoning_delta == "":
pass pass
else: else:
if state.block_type != "thinking": if active_block_type != "thinking":
for event in stop_active_block(): for event in stop_active_block():
yield event yield event
start_event = start_block( start_event = start_block(
...@@ -649,9 +740,9 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -649,9 +740,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield start_event yield start_event
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=( index=(
state.block_index active_block_index
if state.block_index is not None if active_block_index is not None
else state.content_block_index else content_block_index
), ),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
...@@ -666,7 +757,7 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -666,7 +757,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].delta.content == "": if origin_chunk.choices[0].delta.content == "":
pass pass
else: else:
if state.block_type != "text": if active_block_type != "text":
for event in stop_active_block(): for event in stop_active_block():
yield event yield event
start_event = start_block( start_event = start_block(
...@@ -675,9 +766,9 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -675,9 +766,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield start_event yield start_event
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=( index=(
state.block_index active_block_index
if state.block_index is not None if active_block_index is not None
else state.content_block_index else content_block_index
), ),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
...@@ -702,7 +793,7 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -702,7 +793,7 @@ class AnthropicServingMessages(OpenAIServingChat):
else None else None
) )
if ( if (
state.tool_use_id != tool_call.id active_tool_use_id != tool_call.id
and tool_name is not None and tool_name is not None
): ):
for event in stop_active_block(): for event in stop_active_block():
...@@ -720,13 +811,13 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -720,13 +811,13 @@ class AnthropicServingMessages(OpenAIServingChat):
if ( if (
tool_call.function tool_call.function
and tool_call.function.arguments and tool_call.function.arguments
and state.tool_use_id == tool_call.id and active_tool_use_id == tool_call.id
): ):
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=( index=(
state.block_index active_block_index
if state.block_index is not None if active_block_index is not None
else state.content_block_index else content_block_index
), ),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
...@@ -745,13 +836,13 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -745,13 +836,13 @@ class AnthropicServingMessages(OpenAIServingChat):
tool_use_id is not None tool_use_id is not None
and tool_call.function and tool_call.function
and tool_call.function.arguments and tool_call.function.arguments
and state.tool_use_id == tool_use_id and active_tool_use_id == tool_use_id
): ):
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=( index=(
state.block_index active_block_index
if state.block_index is not None if active_block_index is not None
else state.content_block_index else content_block_index
), ),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
......
...@@ -1101,10 +1101,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1101,10 +1101,10 @@ class OpenAIServingChat(OpenAIServing):
index = 0 index = 0
if ( if (
self._should_check_for_unstreamed_tool_arg_tokens( tool_parser
delta_message, output and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
) )
and tool_parser
): ):
latest_delta_len = 0 latest_delta_len = 0
if ( if (
...@@ -1760,6 +1760,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1760,6 +1760,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
delta_message: DeltaMessage | None, delta_message: DeltaMessage | None,
output: CompletionOutput, output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool: ) -> bool:
""" """
Check to see if we should check for unstreamed tool arguments tokens. Check to see if we should check for unstreamed tool arguments tokens.
...@@ -1772,6 +1773,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1772,6 +1773,8 @@ class OpenAIServingChat(OpenAIServing):
# include a function that has arguments # include a function that has arguments
output.finish_reason is not None output.finish_reason is not None
and self.enable_auto_tools and self.enable_auto_tools
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
and self.tool_parser and self.tool_parser
and delta_message and delta_message
and delta_message.tool_calls and delta_message.tool_calls
......
...@@ -262,6 +262,7 @@ def select_fp8_moe_backend( ...@@ -262,6 +262,7 @@ def select_fp8_moe_backend(
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format k_cls, config, weight_key, activation_key, activation_format
) )
supported = True
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls return backend, k_cls
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
...@@ -40,9 +41,11 @@ class SharedHead(nn.Module): ...@@ -40,9 +41,11 @@ class SharedHead(nn.Module):
return self.norm(hidden_states) return self.norm(hidden_states)
@support_torch_compile
class Step3p5AMultiTokenPredictorLayer(nn.Module): class Step3p5AMultiTokenPredictorLayer(nn.Module):
def __init__( def __init__(
self, self,
*,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str, prefix: str,
) -> None: ) -> None:
...@@ -52,7 +55,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module): ...@@ -52,7 +55,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config) self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer( self.mtp_block = Step3p5DecoderLayer(
vllm_config, vllm_config,
prefix=f"{prefix}.mtp_block", prefix=f"{prefix}.mtp_block",
...@@ -64,9 +67,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module): ...@@ -64,9 +67,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0, spec_step_index: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
assert inputs_embeds is not None if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
inputs_embeds = self.enorm(inputs_embeds) inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states) previous_hidden_states = self.hnorm(previous_hidden_states)
...@@ -92,8 +98,8 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -92,8 +98,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict( self.layers = torch.nn.ModuleDict(
{ {
str(idx): Step3p5AMultiTokenPredictorLayer( str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config, vllm_config=vllm_config,
f"{prefix}.layers.{idx}", prefix=f"{prefix}.layers.{idx}",
) )
for idx in range( for idx in range(
self.mtp_start_layer_idx, self.mtp_start_layer_idx,
...@@ -112,14 +118,13 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -112,14 +118,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
positions, positions,
previous_hidden_states, previous_hidden_states,
inputs_embeds, inputs_embeds,
self.embed_tokens,
current_step_idx, current_step_idx,
) )
...@@ -131,7 +136,7 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -131,7 +136,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor( logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
) )
return logits return logits
...@@ -257,6 +262,7 @@ class Step3p5MTP(nn.Module): ...@@ -257,6 +262,7 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".") name = name.replace(".transformer.", ".")
if "shared_head" in name: if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head") name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name: if "embed_tokens" in name:
assert ( assert (
hasattr(self.config, "num_nextn_predict_layers") hasattr(self.config, "num_nextn_predict_layers")
......
...@@ -118,6 +118,12 @@ class ToolParser: ...@@ -118,6 +118,12 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!" "AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
) )
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Whether to check for unstreamed tool-argument tokens in serving
"""
return True
class ToolParserManager: class ToolParserManager:
""" """
......
This diff is collapsed.
...@@ -514,7 +514,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -514,7 +514,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D, seq_threshold_3D=None,
num_par_softmax_segments=num_par_softmax_segments, num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output, softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max, softmax_segm_max=softmax_segm_max,
......
...@@ -957,6 +957,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo ...@@ -957,6 +957,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size( def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec], kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]: ) -> list[KVCacheGroupSpec]:
""" """
...@@ -976,6 +977,12 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -976,6 +977,12 @@ def _get_kv_cache_groups_uniform_page_size(
The KVCacheManager allocates the block_table for each group based on its The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer kv_cache spec, and the model runner applies the block table to each layer
in the group. in the group.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
For example: For example:
1. A model only uses full attention. The pattern is 1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table (num_hidden_layers * full), so there is only one group and the block table
...@@ -1062,19 +1069,28 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1062,19 +1069,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100, num_padding_layers / len(layers) * 100,
) )
num_groups = cdiv(len(layers), group_size) num_groups = cdiv(len(layers), group_size)
# In PP case, say if we have # for support multi layer mtp, we need to
# - stage 0: full.0, sw.0, sw.1 # make all mtp layers in the same group
# - stage 1: full.1, sw.2, sw.3 if (
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) vllm_config.speculative_config is not None
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because and vllm_config.speculative_config.enable_multi_layers_mtp
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) ):
# and it will be padded to (full.0, padding), (sw.0, sw.1), for i in range(0, len(layers), group_size):
# (padding, padding) to ensure the number of layers in each group is grouped_layers.append(layers[i : i + group_size])
# the same and will cause memory waste. else:
# To avoid this, we assign layers[i::num_groups] to the i-th group # In PP case, say if we have
# instead of layers[i * group_size: (i + 1) * group_size] # - stage 0: full.0, sw.0, sw.1
for i in range(num_groups): # - stage 1: full.1, sw.2, sw.3
grouped_layers.append(layers[i::num_groups]) # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
...@@ -1259,7 +1275,9 @@ def get_kv_cache_groups( ...@@ -1259,7 +1275,9 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers # have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page # into groups with the same number of layers, and thus same total page
# size. # size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config( def generate_scheduler_kv_cache_config(
......
...@@ -38,7 +38,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher ...@@ -38,7 +38,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import ( from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID, PADDING_SLOT_ID,
compute_new_slot_mapping, compute_new_slot_mapping,
...@@ -395,6 +395,7 @@ class SpecDecodeBaseProposer: ...@@ -395,6 +395,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample: torch.Tensor | None, token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor] slot_mappings: dict[str, torch.Tensor]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
...@@ -64,3 +66,45 @@ class SpecDecodeMetadata: ...@@ -64,3 +66,45 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices, bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices, logits_indices=logits_indices,
) )
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros(
(1, layer_num), dtype=torch.int32, device=device
)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
This diff is collapsed.
...@@ -53,6 +53,13 @@ class CachedRequestState: ...@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds self.prompt_token_ids, self.prompt_embeds
...@@ -95,6 +102,8 @@ class InputBatch: ...@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
): ):
self.is_pooling_model = is_pooling_model self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
...@@ -223,6 +232,46 @@ class InputBatch: ...@@ -223,6 +232,46 @@ class InputBatch:
) )
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64) self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
...@@ -437,6 +486,13 @@ class InputBatch: ...@@ -437,6 +486,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated. # Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1 self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
lora_id = request.lora_request.lora_int_id lora_id = request.lora_request.lora_int_id
...@@ -632,6 +688,20 @@ class InputBatch: ...@@ -632,6 +688,20 @@ class InputBatch:
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i1],
) )
if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[[i2, i1], ...]
swap_dict_values(self.generators, i1, i2) swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2)
...@@ -769,6 +839,23 @@ class InputBatch: ...@@ -769,6 +839,23 @@ class InputBatch:
if bad_words_token_ids is not None: if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[
last_req_index
]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty. # Decrement last_req_index since it is now empty.
last_req_index -= 1 last_req_index -= 1
......
...@@ -164,7 +164,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer ...@@ -164,7 +164,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer
from vllm.v1.spec_decode.ngram_proposer_gpu import ( from vllm.v1.spec_decode.ngram_proposer_gpu import (
NgramProposerGPU, NgramProposerGPU,
copy_num_valid_draft_tokens, copy_num_valid_draft_tokens,
...@@ -374,6 +375,7 @@ class ExecuteModelState(NamedTuple): ...@@ -374,6 +375,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output: "SchedulerOutput" scheduler_output: "SchedulerOutput"
logits: torch.Tensor logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None spec_decode_metadata: SpecDecodeMetadata | None
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor sample_hidden_states: torch.Tensor
...@@ -500,6 +502,11 @@ class GPUModelRunner( ...@@ -500,6 +502,11 @@ class GPUModelRunner(
self.late_interaction_runner = LateInteractionRunner() self.late_interaction_runner = LateInteractionRunner()
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
# multi layer eagle
self.enable_multi_layer_eagle = False
self.multi_layer_eagle_num = 0
# Set up speculative decoding. # Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on # NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many # the last PP rank. This is not ideal if there are many
...@@ -544,7 +551,17 @@ class GPUModelRunner( ...@@ -544,7 +551,17 @@ class GPUModelRunner(
elif self.speculative_config.method == "suffix": elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config) self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self) if (
self.speculative_config.enable_multi_layers_mtp
and self.speculative_config.method == "mtp"
):
self.enable_multi_layer_eagle = True
self.drafter = MultiLayerEagleProposer(
self.vllm_config, self.device, self
)
self.multi_layer_eagle_num = self.drafter.layer_num
else:
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3": if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = ( self.use_aux_hidden_state_outputs = (
self.drafter.eagle3_use_aux_hidden_state self.drafter.eagle3_use_aux_hidden_state
...@@ -623,6 +640,10 @@ class GPUModelRunner( ...@@ -623,6 +640,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
) )
# Separate cuda stream for overlapping transfer of sampled token ids from # Separate cuda stream for overlapping transfer of sampled token ids from
...@@ -1143,6 +1164,9 @@ class GPUModelRunner( ...@@ -1143,6 +1164,9 @@ class GPUModelRunner(
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state) self._init_xdrope_positions(req_state)
if self.enable_multi_layer_eagle:
self._init_multi_layer_eagle_cache(req_state)
reqs_to_add.append(req_state) reqs_to_add.append(req_state)
# Track new requests for ngram_gpu full tensor copy # Track new requests for ngram_gpu full tensor copy
if is_ngram_gpu: if is_ngram_gpu:
...@@ -1442,6 +1466,24 @@ class GPUModelRunner( ...@@ -1442,6 +1466,24 @@ class GPUModelRunner(
req_state.mm_features, req_state.mm_features,
) )
def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState):
req_state.cached_len = torch.zeros(1, dtype=torch.int64, device=self.device)
req_state.cached_hidden_states = torch.zeros(
self.multi_layer_eagle_num,
self.model_config.get_hidden_size(),
dtype=self.dtype,
device=self.device,
)
req_state.cached_token_ids = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int32, device=self.device
)
req_state.cached_positions = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int64, device=self.device
)
req_state.cached_slot_mappings = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int64, device=self.device
)
def _extract_mm_kwargs( def _extract_mm_kwargs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -1672,10 +1714,11 @@ class GPUModelRunner( ...@@ -1672,10 +1714,11 @@ class GPUModelRunner(
) -> tuple[ ) -> tuple[
torch.Tensor, torch.Tensor,
SpecDecodeMetadata | None, SpecDecodeMetadata | None,
MultiLayerEagleMetadata | None,
]: ]:
""" """
:return: tuple[ :return: tuple[
logits_indices, spec_decode_metadata, logits_indices, spec_decode_metadata, multi_layer_eagle_metadata,
] ]
""" """
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
...@@ -1879,9 +1922,21 @@ class GPUModelRunner( ...@@ -1879,9 +1922,21 @@ class GPUModelRunner(
self.input_batch, num_scheduled_tokens, num_sampled_tokens self.input_batch, num_scheduled_tokens, num_sampled_tokens
) )
if self.enable_multi_layer_eagle:
multi_layer_eagle_metadata = MultiLayerEagleMetadata(
cached_len=self.input_batch.cached_len[:num_reqs],
cached_token_ids=self.input_batch.cached_token_ids[:num_reqs],
cached_hidden_states=self.input_batch.cached_hidden_states[:num_reqs],
cached_slot_mappings=self.input_batch.cached_slot_mappings[:num_reqs],
cached_positions=self.input_batch.cached_positions[:num_reqs],
)
else:
multi_layer_eagle_metadata = None
return ( return (
logits_indices, logits_indices,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
) )
def _build_attention_metadata( def _build_attention_metadata(
...@@ -3634,9 +3689,11 @@ class GPUModelRunner( ...@@ -3634,9 +3689,11 @@ class GPUModelRunner(
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
logits_indices, spec_decode_metadata = self._prepare_inputs( logits_indices, spec_decode_metadata, multi_layer_eagle_metadata = (
scheduler_output, self._prepare_inputs(
num_scheduled_tokens_np, scheduler_output,
num_scheduled_tokens_np,
)
) )
cascade_attn_prefix_lens = None cascade_attn_prefix_lens = None
...@@ -3867,6 +3924,7 @@ class GPUModelRunner( ...@@ -3867,6 +3924,7 @@ class GPUModelRunner(
scheduler_output, scheduler_output,
logits, logits,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
hidden_states, hidden_states,
sample_hidden_states, sample_hidden_states,
...@@ -3905,6 +3963,7 @@ class GPUModelRunner( ...@@ -3905,6 +3963,7 @@ class GPUModelRunner(
scheduler_output, scheduler_output,
logits, logits,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
hidden_states, hidden_states,
sample_hidden_states, sample_hidden_states,
...@@ -3953,6 +4012,7 @@ class GPUModelRunner( ...@@ -3953,6 +4012,7 @@ class GPUModelRunner(
sample_hidden_states, sample_hidden_states,
aux_hidden_states, aux_hidden_states,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
slot_mappings, slot_mappings,
) )
...@@ -4242,6 +4302,7 @@ class GPUModelRunner( ...@@ -4242,6 +4302,7 @@ class GPUModelRunner(
sample_hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None, aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None, spec_decode_metadata: SpecDecodeMetadata | None,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor: ) -> list[list[int]] | torch.Tensor:
...@@ -4466,6 +4527,7 @@ class GPUModelRunner( ...@@ -4466,6 +4527,7 @@ class GPUModelRunner(
token_indices_to_sample=token_indices_to_sample, token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu, num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
...@@ -6216,6 +6278,10 @@ class GPUModelRunner( ...@@ -6216,6 +6278,10 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs, logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
) )
assert self._init_block_sizes == block_sizes, ( assert self._init_block_sizes == block_sizes, (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try: try:
from ._version import __version__, __version_tuple__ __version__ = "0.18.1"
__version_tuple__ = (0, 18, 1)
__hcu_version__ = f'0.18.1+das.ori.dtk2604'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
import warnings import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str): def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version. '''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version. Return True if version_str matches the previous minor version.
...@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str): ...@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'. supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version. Used for --show-hidden-metrics-for-version.
""" '''
# Match anything if this is a dev tree # Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0): if __version_tuple__[0:2] == (0, 0):
return True return True
# Note - this won't do the right thing when we release 1.0! # Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0 # assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version(): def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number.""" '''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine" # In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment