Commit 84ac1d27 authored by zhangning3's avatar zhangning3
Browse files

add step3p5 mtp3

parent b27f1671
......@@ -57,6 +57,11 @@ def parse_args():
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true")
parser.add_argument("--enable-chunked-prefill", action="store_true")
parser.add_argument(
"--enable-multi-layers-mtp",
action="store_true",
help="Enable multi-layer MTP (only effective when --method=mtp).",
)
parser.add_argument("--max-model-len", type=int, default=16384)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--top-p", type=float, default=1.0)
......@@ -66,12 +71,14 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--tokenizer-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="")
parser.add_argument("--trust-remote-code", action="store_true")
return parser.parse_args()
......@@ -85,7 +92,11 @@ def main(args):
"please specify model_dir to give a mm based model"
)
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer_dir = args.tokenizer_dir
if tokenizer_dir is None:
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
if args.custom_mm_prompts:
prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
......@@ -141,6 +152,8 @@ def main(args):
"method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
if args.enable_multi_layers_mtp:
speculative_config["enable_multi_layers_mtp"] = True
else:
raise ValueError(f"unknown method: {args.method}")
......
#!/usr/bin/env bash
# Benchmark driver: runs `vllm bench serve` once per batch size against an
# already-running server and stores the results per role.

# Treat expansion of an unset variable as an error.
set -u
# Batch sizes used when no positional batch-size list is given.
DEFAULT_BATCH_SIZES=(1 8 16 32 64 128)
# Passed to `vllm bench serve` as --model.
MODEL_PATH="/module/step3.5-fp8/"
# Passed as --served-model-name.
SERVED_MODEL_NAME="/module/step3.5-fp8/"
# Passed as --dataset-name.
DATASET_NAME="random"
# Output length used for role=decode when --output-len is not supplied.
DEFAULT_OUTPUT_LEN_DECODE=4096
# Output length that is always forced for role=prefill.
DEFAULT_OUTPUT_LEN_PREFILL=1
# Role used when --role is not supplied.
DEFAULT_ROLE="decode"
# Passed as --ready-check-timeout (presumably seconds -- confirm against the vllm CLI).
READY_CHECK_TIMEOUT=3
# Parent directory for results; each role writes into "$RESULT_DIR/<role>".
RESULT_DIR="benchmark_result"
# Print the command-line help text to stdout.
# The heredoc delimiter is quoted ('USAGE'), so the text is emitted literally
# with no variable or command expansion.
print_usage() {
cat <<'USAGE'
Usage:
./scripts/step3p5_benchmark_test.sh
./scripts/step3p5_benchmark_test.sh 1,8,16,32
./scripts/step3p5_benchmark_test.sh --role prefill
./scripts/step3p5_benchmark_test.sh --role decode
./scripts/step3p5_benchmark_test.sh --role both
./scripts/step3p5_benchmark_test.sh --role prefill --output-len 1
Description:
- No argument: use default batch sizes, role=decode, output-len=4096
- Optional positional argument: batch size list (comma or space separated)
- Optional flag: --role <prefill|decode|both>
- Optional flag: --output-len <N> (must be positive integer)
- role=both 时串行执行 prefill 再 decode
- Result files are saved under:
<result_dir>/prefill (when role=prefill)
<result_dir>/decode (when role=decode)
USAGE
}
# Fill the global BATCH_SIZES array from a comma- and/or space-separated list.
#   $1 (optional): the list; when empty or absent, DEFAULT_BATCH_SIZES is used.
# Exits with status 1 if the list yields no entries or if any entry is not a
# positive integer without leading zeros.
parse_batch_sizes() {
    local spec="${1:-}"

    # Nothing supplied: fall back to the built-in defaults.
    if [[ -z "$spec" ]]; then
        BATCH_SIZES=("${DEFAULT_BATCH_SIZES[@]}")
        return
    fi

    # Turn commas into spaces so one word-splitting read handles both separators.
    spec="${spec//,/ }"
    read -r -a BATCH_SIZES <<< "$spec"

    if [[ ${#BATCH_SIZES[@]} -eq 0 ]]; then
        echo "[ERROR] batch size 列表为空。"
        print_usage
        exit 1
    fi

    local entry
    for entry in "${BATCH_SIZES[@]}"; do
        if [[ "$entry" =~ ^[1-9][0-9]*$ ]]; then
            continue
        fi
        echo "[ERROR] 非法 batch size: ${entry}(必须是正整数)"
        exit 1
    done
}
# Validate the requested role and store it in the global ROLE.
#   $1 (optional): prefill, decode or both; defaults to DEFAULT_ROLE.
# Prints an error plus the usage text and exits with status 1 for any other value.
parse_role() {
    local candidate="${1:-$DEFAULT_ROLE}"

    case "$candidate" in
        prefill|decode|both)
            ROLE="$candidate"
            ;;
        *)
            echo "[ERROR] 非法 role: ${candidate}(必须是 prefill、decode 或 both)"
            print_usage
            exit 1
            ;;
    esac
}
# Validate the output length and store it in the global RANDOM_OUTPUT_LEN.
#   $1: must be a positive integer without leading zeros.
# Prints an error plus the usage text and exits with status 1 otherwise.
parse_output_len() {
    local requested="$1"

    # Happy path first: accept the value and leave.
    if [[ "$requested" =~ ^[1-9][0-9]*$ ]]; then
        RANDOM_OUTPUT_LEN="$requested"
        return
    fi

    echo "[ERROR] 非法 output-len: ${requested}(必须是正整数)"
    print_usage
    exit 1
}
# Print the number of prompts to send for a given batch size.
#   $1: batch size (max concurrency).
# The value is 1.5x the batch size (integer arithmetic), clamped to the range
# [16, 384], and finally raised to the batch size itself if still below it.
calc_num_prompts() {
    local concurrency="$1"
    local total=$(( concurrency + concurrency / 2 ))

    # Clamp into [16, 384].
    total=$(( total < 16 ? 16 : total ))
    total=$(( total > 384 ? 384 : total ))

    # Never send fewer prompts than the concurrency level.
    if (( total < concurrency )); then
        total=$concurrency
    fi

    echo "$total"
}
# Entry point: parse the command line, then run one `vllm bench serve`
# invocation per batch size for the selected role.
#
# Arguments:
#   [batch list]           optional positional, comma/space separated
#   -r, --role <role>      prefill | decode | both (default: DEFAULT_ROLE)
#   -o, --output-len <N>   positive integer; ignored (forced to 1) for prefill
#   -h, --help             print usage and exit 0
#
# Exit status: 0 when every benchmark run succeeded, 1 on a usage error or
# when at least one run (or, for role=both, at least one phase) failed.
main() {
    local batch_arg=""
    local role_arg="$DEFAULT_ROLE"
    local output_len_arg=""

    while (( $# > 0 )); do
        case "$1" in
            -h|--help)
                print_usage
                exit 0
                ;;
            --role|-r)
                if [[ -z "${2:-}" ]]; then
                    echo "[ERROR] --role 缺少参数。"
                    print_usage
                    exit 1
                fi
                role_arg="$2"
                shift 2
                ;;
            --output-len|-o)
                if [[ -z "${2:-}" ]]; then
                    echo "[ERROR] --output-len 缺少参数。"
                    print_usage
                    exit 1
                fi
                output_len_arg="$2"
                shift 2
                ;;
            -*)
                # Any other dash-prefixed word is an unknown option. Matching
                # single-dash words too keeps e.g. "-x" from being mistaken
                # for the batch-size list.
                echo "[ERROR] 未知参数: $1"
                print_usage
                exit 1
                ;;
            *)
                if [[ -n "$batch_arg" ]]; then
                    echo "[ERROR] 仅支持一个 batch size 列表参数。"
                    print_usage
                    exit 1
                fi
                batch_arg="$1"
                shift
                ;;
        esac
    done

    parse_batch_sizes "$batch_arg"
    parse_role "$role_arg"

    if [[ "$ROLE" == "both" ]]; then
        # role=both re-invokes this script twice: prefill first, then decode.
        local -a prefill_cmd=("$0")
        local -a decode_cmd=("$0")
        if [[ -n "$batch_arg" ]]; then
            prefill_cmd+=("$batch_arg")
            decode_cmd+=("$batch_arg")
        fi
        prefill_cmd+=("--role" "prefill")
        decode_cmd+=("--role" "decode")
        # --output-len only applies to decode; prefill always uses its fixed length.
        if [[ -n "$output_len_arg" ]]; then
            decode_cmd+=("--output-len" "$output_len_arg")
        fi

        # Record failures of either phase so the overall exit status reflects
        # them; decode still runs after a failed prefill.
        local phase_failed=0
        echo "[INFO] role=both: 将串行执行 prefill 和 decode"
        echo "[INFO] step1: ${prefill_cmd[*]}"
        if ! "${prefill_cmd[@]}"; then
            echo "[WARN] prefill 阶段执行失败,继续执行 decode。"
            phase_failed=1
        fi
        echo "[INFO] step2: ${decode_cmd[*]}"
        if ! "${decode_cmd[@]}"; then
            echo "[WARN] decode 阶段执行失败。"
            phase_failed=1
        fi
        if (( phase_failed != 0 )); then
            echo "[WARN] role=both 执行结束,但存在失败的阶段。"
            exit 1
        fi
        echo "[INFO] role=both 执行完成。"
        return 0
    fi

    if [[ "$ROLE" == "prefill" ]]; then
        # Prefill runs always use the fixed prefill output length; warn when a
        # different value was requested explicitly.
        if [[ -n "$output_len_arg" && "$output_len_arg" != "$DEFAULT_OUTPUT_LEN_PREFILL" ]]; then
            echo "[WARN] role=prefill 时 output-len 必须为 1,已自动覆盖为 1。"
        fi
        output_len_arg="$DEFAULT_OUTPUT_LEN_PREFILL"
    elif [[ -z "$output_len_arg" ]]; then
        output_len_arg="$DEFAULT_OUTPUT_LEN_DECODE"
    fi
    parse_output_len "$output_len_arg"

    RESULT_SUBDIR="$RESULT_DIR/$ROLE"
    mkdir -p "$RESULT_SUBDIR"

    echo "[INFO] 将执行 ${#BATCH_SIZES[@]} 组 benchmark"
    echo "[INFO] role: $ROLE"
    echo "[INFO] random-output-len: $RANDOM_OUTPUT_LEN"
    echo "[INFO] batch size 列表: ${BATCH_SIZES[*]}"
    echo "[INFO] result_dir: $RESULT_SUBDIR"

    local batch_size
    local num_prompts
    local failed_count=0
    for batch_size in "${BATCH_SIZES[@]}"; do
        num_prompts="$(calc_num_prompts "$batch_size")"
        echo ""
        echo "[INFO] 开始测试: role=$ROLE, batch_size=$batch_size, max_concurrency=$batch_size, num_prompts=$num_prompts, random_output_len=$RANDOM_OUTPUT_LEN"
        # A failing run is counted but does not stop the remaining batch sizes.
        if ! vllm bench serve \
            --backend vllm \
            --model "$MODEL_PATH" \
            --served-model-name "$SERVED_MODEL_NAME" \
            --dataset-name "$DATASET_NAME" \
            --random-input-len 65536 \
            --random-output-len "$RANDOM_OUTPUT_LEN" \
            --num-prompts "$num_prompts" \
            --temperature 0 \
            --max-concurrency "$batch_size" \
            --ready-check-timeout "$READY_CHECK_TIMEOUT" \
            --result-dir "$RESULT_SUBDIR" \
            --port 8018 \
            --save-result; then
            echo "[WARN] batch_size=$batch_size 执行失败,继续下一个。"
            failed_count=$((failed_count + 1))
        else
            echo "[INFO] batch_size=$batch_size 执行完成。"
        fi
    done

    echo ""
    if (( failed_count > 0 )); then
        echo "[WARN] 全部执行结束,但有 $failed_count 组失败。"
        exit 1
    fi
    echo "[INFO] 全部 benchmark 执行完成。"
}
main "$@"
......@@ -386,6 +386,57 @@ TX
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
def test_extract_tool_calls_missing_function_name(step3p5_tool_parser, sample_tools):
    """A tool call with no function tag is not parsed; the raw text is the content."""
    raw_output = (
        "<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
    )
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    parsed = step3p5_tool_parser.extract_tool_calls(raw_output, request=chat_request)

    # Nothing is recognised as a tool call and the output passes through verbatim.
    assert not parsed.tools_called
    assert parsed.tool_calls == []
    assert parsed.content == raw_output
def test_extract_tool_calls_empty_function_name(step3p5_tool_parser, sample_tools):
    """An empty name in the function tag is not parsed; the raw text is the content."""
    raw_output = (
        "<tool_call><function=><parameter=pattern>*.py</parameter>"
        "</function></tool_call>"
    )
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    parsed = step3p5_tool_parser.extract_tool_calls(raw_output, request=chat_request)

    # Nothing is recognised as a tool call and the output passes through verbatim.
    assert not parsed.tools_called
    assert parsed.tool_calls == []
    assert parsed.content == raw_output
def test_extract_tool_calls_empty_function_name_single_quotes(
    step3p5_tool_parser, sample_tools
):
    """A function name given as empty single quotes is not parsed as a tool call."""
    raw_output = (
        "<tool_call><function=''><parameter=pattern>*.py</parameter></tool_call>"
    )
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    parsed = step3p5_tool_parser.extract_tool_calls(raw_output, request=chat_request)

    # Nothing is recognised as a tool call and the output passes through verbatim.
    assert not parsed.tools_called
    assert parsed.tool_calls == []
    assert parsed.content == raw_output
def test_extract_tool_calls_type_conversion(step3p5_tool_parser):
"""Test parameter type conversion based on tool schema"""
tools = [
......@@ -623,7 +674,7 @@ def test_extract_tool_calls_streaming(
expected_tool_calls,
expected_content,
):
"""Test incremental streaming behavior including typed parameters"""
"""Test streaming returns complete tool calls when parsed."""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ""
......@@ -647,11 +698,10 @@ def test_extract_tool_calls_streaming(
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
# First chunk should have id, name, and type
if tool_call.id:
tool_states[idx]["id"] = tool_call.id
......@@ -666,8 +716,15 @@ def test_extract_tool_calls_streaming(
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx]["arguments"] += tool_call.function.arguments
# Arguments should be complete JSON when emitted.
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Verify final content
assert other_content == (expected_content or "") # Handle None case
......@@ -682,7 +739,7 @@ def test_extract_tool_calls_streaming(
assert state["type"] == "function"
assert state["name"] == expected_tool.function.name
# Parse accumulated arguments
# Parse arguments
arguments_str = state["arguments"]
assert arguments_str is not None
actual_args = json.loads(arguments_str)
......@@ -770,7 +827,7 @@ fahrenheit
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
......@@ -786,7 +843,14 @@ fahrenheit
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Verify content was streamed
assert "Let me check the weather for you:" in other_content
......@@ -806,62 +870,69 @@ fahrenheit
assert args["unit"] == "fahrenheit"
def test_extract_tool_calls_streaming_incremental(
def test_extract_tool_calls_streaming_missing_function_name(
step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
"""Test that streaming is truly incremental"""
model_output = """I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
"""Streaming: missing function name should be treated as content."""
model_output = (
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
chunks = []
for delta_message in stream_delta_message_generator(
step3p5_tool_parser, step3p5_tokenizer, model_output, request
other_content = ""
tool_calls = []
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser,
step3p5_tokenizer,
[
"<tool_call><parameter=pattern>",
"*.py</parameter>",
"</function></tool_call>",
],
request,
):
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
tool_calls.extend(delta_message.tool_calls)
assert other_content == model_output
assert tool_calls == []
def test_extract_tool_calls_streaming_empty_function_name(
step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
"""Streaming: empty function name should be treated as content."""
model_output = (
"<tool_call><function=''><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ""
tool_calls = []
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser,
step3p5_tokenizer,
[
"<tool_call><function=",
"''><parameter=pattern>*.py</parameter>",
"</function></tool_call>",
],
request,
):
chunks.append(delta_message)
# Should have multiple chunks
assert len(chunks) > 3
# First chunk(s) should be content
assert chunks[0].content is not None
assert chunks[0].tool_calls is None or chunks[0].tool_calls == []
# Should have a chunk with tool header (id, name, type)
header_found = False
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert chunk.tool_calls[0].function.name == "get_current_weather"
assert chunk.tool_calls[0].type == "function"
# Empty initially
assert chunk.tool_calls[0].function.arguments == ""
break
assert header_found
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].function.arguments:
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
assert len(arg_chunks) > 1
# Concatenated arguments should form valid JSON
full_args = "".join(arg_chunks)
parsed_args = json.loads(full_args)
assert parsed_args["city"] == "Dallas"
assert parsed_args["state"] == "TX"
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
tool_calls.extend(delta_message.tool_calls)
assert other_content == model_output
assert tool_calls == []
def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser):
......@@ -951,7 +1022,7 @@ rectangle
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
......@@ -967,7 +1038,14 @@ rectangle
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls"
......@@ -1164,7 +1242,7 @@ rectangle
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
if tool_call.id:
......@@ -1175,7 +1253,14 @@ rectangle
if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls"
......@@ -1266,7 +1351,7 @@ rectangle
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"arguments": None,
"type": None,
}
......@@ -1282,7 +1367,14 @@ rectangle
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
# Should have exactly two complete tool calls
assert len(tool_states) == 2, "Should have exactly two complete tool calls"
......@@ -1344,20 +1436,26 @@ rectangle""",
for delta_message in stream_delta_message_generator_from_chunks(
step3p5_tool_parser, step3p5_tokenizer, delta_text_chunks, request
):
print(delta_message)
if delta_message.tool_calls:
for tool_call in delta_message.tool_calls:
idx = tool_call.index
if idx not in tool_states:
tool_states[idx] = {
"name": None,
"arguments": "",
"arguments": None,
}
if tool_call.function:
if tool_call.function.name:
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx]["arguments"] += tool_call.function.arguments
json.loads(tool_call.function.arguments)
if tool_states[idx]["arguments"] is None:
tool_states[idx]["arguments"] = tool_call.function.arguments
else:
assert (
tool_states[idx]["arguments"]
== tool_call.function.arguments
)
assert len(tool_states) == 2
assert all(state["name"] for state in tool_states.values())
......@@ -1368,7 +1466,7 @@ rectangle""",
def test_extract_tool_calls_non_streaming_multiple_tool_calls_no_content_between(
step3p5_tool_parser, sample_tools
):
"""Test non-streaming extraction with tool calls and no content between them.
"""Test non-streaming extraction with multiple tool calls.
Scenario: Model outputs "hello" + tool call + tool call.
Expected: "hello" as content, first tool call parsed (index=0),
......
This diff is collapsed.
......@@ -80,6 +80,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
......@@ -493,7 +497,10 @@ class SpeculativeConfig:
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
......
......@@ -166,7 +166,72 @@ class AnthropicServingMessages(OpenAIServingChat):
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
cls._convert_message_content(msg, openai_msg, openai_messages)
# Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
"function": {
"name": block.name or "",
"arguments": json.dumps(block.input or {}),
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.id or "",
"content": str(block.content)
if block.content
else "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
}
)
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls and not reasoning_parts:
continue
openai_messages.append(openai_msg)
......@@ -522,49 +587,75 @@ class AnthropicServingMessages(OpenAIServingChat):
first_item = True
finish_reason = None
state = _ActiveBlockState()
content_block_index = 0
active_block_type: str | None = None
active_block_index: int | None = None
active_block_signature: str | None = None
signature_emitted = False
active_tool_use_id: str | None = None
# Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
events: list[str] = []
if state.block_type is None:
if active_block_type is None:
return events
if (
state.block_type == "thinking"
and state.block_signature is not None
and not state.signature_emitted
active_block_type == "thinking"
and active_block_signature is not None
and not signature_emitted
):
chunk = AnthropicStreamEvent(
index=state.block_index,
index=active_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="signature_delta",
signature=state.block_signature,
signature=active_block_signature,
),
)
data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta"))
state.signature_emitted = True
signature_emitted = True
stop_chunk = AnthropicStreamEvent(
index=state.block_index,
index=active_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop"))
state.reset()
state.content_block_index += 1
active_block_type = None
active_block_index = None
active_block_signature = None
signature_emitted = False
active_tool_use_id = None
content_block_index += 1
return events
def start_block(block: AnthropicContentBlock):
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
chunk = AnthropicStreamEvent(
index=state.content_block_index,
index=content_block_index,
type="content_block_start",
content_block=block,
)
data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start")
state.start(block)
active_block_type = block.type
active_block_index = content_block_index
if block.type == "thinking":
active_block_signature = uuid.uuid4().hex
signature_emitted = False
active_tool_use_id = None
elif block.type == "tool_use":
active_block_signature = None
signature_emitted = True
active_tool_use_id = block.id
else:
active_block_signature = None
signature_emitted = True
active_tool_use_id = None
return event
async for item in generator:
......@@ -638,7 +729,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if reasoning_delta == "":
pass
else:
if state.block_type != "thinking":
if active_block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
......@@ -649,9 +740,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield start_event
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......@@ -666,7 +757,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].delta.content == "":
pass
else:
if state.block_type != "text":
if active_block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
......@@ -675,9 +766,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield start_event
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......@@ -702,7 +793,7 @@ class AnthropicServingMessages(OpenAIServingChat):
else None
)
if (
state.tool_use_id != tool_call.id
active_tool_use_id != tool_call.id
and tool_name is not None
):
for event in stop_active_block():
......@@ -720,13 +811,13 @@ class AnthropicServingMessages(OpenAIServingChat):
if (
tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_call.id
and active_tool_use_id == tool_call.id
):
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......@@ -745,13 +836,13 @@ class AnthropicServingMessages(OpenAIServingChat):
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_use_id
and active_tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent(
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
......
......@@ -1101,10 +1101,10 @@ class OpenAIServingChat(OpenAIServing):
index = 0
if (
self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output
tool_parser
and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
)
and tool_parser
):
latest_delta_len = 0
if (
......@@ -1760,6 +1760,7 @@ class OpenAIServingChat(OpenAIServing):
self,
delta_message: DeltaMessage | None,
output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
......@@ -1772,6 +1773,8 @@ class OpenAIServingChat(OpenAIServing):
# include a function that has arguments
output.finish_reason is not None
and self.enable_auto_tools
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
and self.tool_parser
and delta_message
and delta_message.tool_calls
......
......@@ -262,6 +262,7 @@ def select_fp8_moe_backend(
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
supported = True
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
......
......@@ -6,6 +6,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
......@@ -40,9 +41,11 @@ class SharedHead(nn.Module):
return self.norm(hidden_states)
@support_torch_compile
class Step3p5AMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str,
) -> None:
......@@ -52,7 +55,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer(
vllm_config,
prefix=f"{prefix}.mtp_block",
......@@ -64,9 +67,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
......@@ -92,8 +98,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict(
{
str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config,
f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
......@@ -112,14 +118,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
self.embed_tokens,
current_step_idx,
)
......@@ -131,7 +136,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
)
return logits
......@@ -257,6 +262,7 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".")
if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name:
assert (
hasattr(self.config, "num_nextn_predict_layers")
......
......@@ -118,6 +118,12 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
    """
    Whether to check for unstreamed tool-argument tokens in serving.

    OpenAIServingChat._should_check_for_unstreamed_tool_arg_tokens calls this
    as one of its conditions. This base implementation always returns True.

    Returns:
        True, meaning the serving-side check is allowed for this parser.
    """
    return True
class ToolParserManager:
"""
......
This diff is collapsed.
......@@ -514,7 +514,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
seq_threshold_3D=None,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
......
......@@ -957,6 +957,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
......@@ -976,6 +977,12 @@ def _get_kv_cache_groups_uniform_page_size(
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
......@@ -1062,6 +1069,15 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100,
)
num_groups = cdiv(len(layers), group_size)
# To support multi-layer MTP, all MTP layers must be
# placed in the same group.
if (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.enable_multi_layers_mtp
):
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i : i + group_size])
else:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
......@@ -1259,7 +1275,9 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config(
......
......@@ -38,7 +38,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID,
compute_new_slot_mapping,
......@@ -395,6 +395,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
......@@ -64,3 +66,45 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@dataclass
class MultiLayerEagleMetadata:
    """Bundle of cached per-request tensors passed to the multi-layer eagle proposer.

    Every field defaults to ``None``. The expected shapes are given in the
    comment above each field; ``layer_num`` is the second dimension of every
    2-D/3-D cache.
    """

    # [batch_size]
    cached_len: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_token_ids: torch.Tensor | None = None
    # [batch_size, layer_num, hidden_size]
    cached_hidden_states: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_slot_mappings: torch.Tensor | None = None
    # [batch_size, layer_num]
    cached_positions: torch.Tensor | None = None

    @classmethod
    def make_dummy(
        cls,
        layer_num: int,
        hidden_size: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> "MultiLayerEagleMetadata":
        """Build a zero-filled instance with a batch size of one.

        Args:
            layer_num: Size of the per-layer dimension of each cache.
            hidden_size: Size of the last dimension of ``cached_hidden_states``.
            device: Device on which all tensors are allocated.
            dtype: Dtype of ``cached_hidden_states``. Defaults to
                ``torch.float32`` (the previously hard-coded value) so that
                existing callers are unaffected.

        Returns:
            A ``MultiLayerEagleMetadata`` whose tensors are all zeros.
        """
        batch_size = 1
        per_layer_shape = (batch_size, layer_num)

        return cls(
            # Explicit one-element tuple: ``(1)`` is just the integer 1.
            cached_len=torch.zeros((batch_size,), dtype=torch.int64, device=device),
            cached_token_ids=torch.zeros(
                per_layer_shape, dtype=torch.int32, device=device
            ),
            cached_hidden_states=torch.zeros(
                (batch_size, layer_num, hidden_size), dtype=dtype, device=device
            ),
            cached_slot_mappings=torch.zeros(
                per_layer_shape, dtype=torch.int64, device=device
            ),
            cached_positions=torch.zeros(
                per_layer_shape, dtype=torch.int64, device=device
            ),
        )
This diff is collapsed.
......@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
......@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
......@@ -223,6 +232,46 @@ class InputBatch:
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
......@@ -437,6 +486,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
......@@ -632,6 +688,20 @@ class InputBatch:
self.num_accepted_tokens_cpu[i1],
)
if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[[i2, i1], ...]
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
......@@ -769,6 +839,23 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[
last_req_index
]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
......
......@@ -164,7 +164,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer
from vllm.v1.spec_decode.ngram_proposer_gpu import (
NgramProposerGPU,
copy_num_valid_draft_tokens,
......@@ -374,6 +375,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
......@@ -500,6 +502,11 @@ class GPUModelRunner(
self.late_interaction_runner = LateInteractionRunner()
self.use_aux_hidden_state_outputs = False
# multi layer eagle
self.enable_multi_layer_eagle = False
self.multi_layer_eagle_num = 0
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
......@@ -544,6 +551,16 @@ class GPUModelRunner(
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
if (
self.speculative_config.enable_multi_layers_mtp
and self.speculative_config.method == "mtp"
):
self.enable_multi_layer_eagle = True
self.drafter = MultiLayerEagleProposer(
self.vllm_config, self.device, self
)
self.multi_layer_eagle_num = self.drafter.layer_num
else:
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = (
......@@ -623,6 +640,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
)
# Separate cuda stream for overlapping transfer of sampled token ids from
......@@ -1143,6 +1164,9 @@ class GPUModelRunner(
if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state)
if self.enable_multi_layer_eagle:
self._init_multi_layer_eagle_cache(req_state)
reqs_to_add.append(req_state)
# Track new requests for ngram_gpu full tensor copy
if is_ngram_gpu:
......@@ -1442,6 +1466,24 @@ class GPUModelRunner(
req_state.mm_features,
)
def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState):
    """Attach zero-filled multi-layer eagle cache tensors to ``req_state``.

    All tensors are created on ``self.device``:

    * ``cached_len``: shape ``(1,)``, int64.
    * ``cached_hidden_states``: shape ``(multi_layer_eagle_num, hidden_size)``
      in ``self.dtype``, where ``hidden_size`` comes from
      ``self.model_config.get_hidden_size()``.
    * ``cached_token_ids``: shape ``(multi_layer_eagle_num,)``, int32.
    * ``cached_positions`` / ``cached_slot_mappings``: shape
      ``(multi_layer_eagle_num,)``, int64.
    """
    num_layers = self.multi_layer_eagle_num
    target_device = self.device

    req_state.cached_len = torch.zeros(1, dtype=torch.int64, device=target_device)
    req_state.cached_hidden_states = torch.zeros(
        (num_layers, self.model_config.get_hidden_size()),
        dtype=self.dtype,
        device=target_device,
    )

    # The remaining caches are 1-D with one entry per layer and differ only
    # in their integer dtype.
    per_layer_dtypes = {
        "cached_token_ids": torch.int32,
        "cached_positions": torch.int64,
        "cached_slot_mappings": torch.int64,
    }
    for attr_name, attr_dtype in per_layer_dtypes.items():
        setattr(
            req_state,
            attr_name,
            torch.zeros(num_layers, dtype=attr_dtype, device=target_device),
        )
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
......@@ -1672,10 +1714,11 @@ class GPUModelRunner(
) -> tuple[
torch.Tensor,
SpecDecodeMetadata | None,
MultiLayerEagleMetadata | None,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
logits_indices, spec_decode_metadata, multi_layer_eagle_metadata,
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
......@@ -1879,9 +1922,21 @@ class GPUModelRunner(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
if self.enable_multi_layer_eagle:
multi_layer_eagle_metadata = MultiLayerEagleMetadata(
cached_len=self.input_batch.cached_len[:num_reqs],
cached_token_ids=self.input_batch.cached_token_ids[:num_reqs],
cached_hidden_states=self.input_batch.cached_hidden_states[:num_reqs],
cached_slot_mappings=self.input_batch.cached_slot_mappings[:num_reqs],
cached_positions=self.input_batch.cached_positions[:num_reqs],
)
else:
multi_layer_eagle_metadata = None
return (
logits_indices,
spec_decode_metadata,
multi_layer_eagle_metadata,
)
def _build_attention_metadata(
......@@ -3634,10 +3689,12 @@ class GPUModelRunner(
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
logits_indices, spec_decode_metadata = self._prepare_inputs(
logits_indices, spec_decode_metadata, multi_layer_eagle_metadata = (
self._prepare_inputs(
scheduler_output,
num_scheduled_tokens_np,
)
)
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
......@@ -3867,6 +3924,7 @@ class GPUModelRunner(
scheduler_output,
logits,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
......@@ -3905,6 +3963,7 @@ class GPUModelRunner(
scheduler_output,
logits,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
......@@ -3953,6 +4012,7 @@ class GPUModelRunner(
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
slot_mappings,
)
......@@ -4242,6 +4302,7 @@ class GPUModelRunner(
sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor:
......@@ -4466,6 +4527,7 @@ class GPUModelRunner(
token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings,
......@@ -6216,6 +6278,10 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
)
assert self._init_block_sizes == block_sizes, (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Resolve the package version. The values are set first and then re-imported
# from ``vllm.version``; if anything in that sequence fails, warn and fall
# back to a "dev" version so the module still imports.
try:
    __version__ = "0.18.1"
    __version_tuple__ = (0, 18, 1)
    __hcu_version__ = "0.18.1+das.ori.dtk2604"
    from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
    import warnings

    # Interpolate the exception into the message; previously the f-string
    # contained the literal text " + str(e)" and dropped the actual error.
    warnings.warn(
        f"Failed to read commit hash:\n{e}",
        RuntimeWarning,
        stacklevel=2,
    )

    __version__ = "dev"
    __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
......@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
    """Return the previous minor version as ``"<major>.<minor - 1>"``.

    Intended for testing. The major and minor components are read from the
    module-level ``__version_tuple__``.
    """
    # In a dev tree the minor component is 0, so this yields "0.-1";
    # that value still works fine for the callers.
    major, minor = __version_tuple__[0], __version_tuple__[1]
    assert isinstance(minor, int)
    return f"{major}.{minor - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment