Unverified Commit d0009ddb authored by stevenkuang's avatar stevenkuang Committed by GitHub
Browse files

[Model] Support Hy3 preview (#40681)


Signed-off-by: stevenkuang <stevenkuang@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 424033f4
...@@ -419,6 +419,7 @@ th { ...@@ -419,6 +419,7 @@ th {
| `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2` | ✅︎ | ✅︎ | | `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2` | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ | | `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | | `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HYV3ForCausalLM` | HY3 | `tencent/Hy3-preview-Base`, `tencent/Hy3-preview` | ✅︎ | ✅︎ |
| `HyperCLOVAXForCausalLM` | HyperCLOVAX-SEED-Think-14B | `naver-hyperclovax/HyperCLOVAX-SEED-Think-14B` | ✅︎ | ✅︎ | | `HyperCLOVAXForCausalLM` | HyperCLOVAX-SEED-Think-14B | `naver-hyperclovax/HyperCLOVAX-SEED-Think-14B` | ✅︎ | ✅︎ |
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ |
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
......
...@@ -324,6 +324,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -324,6 +324,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo( "HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-A13B-Instruct", trust_remote_code=True "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
), ),
"HYV3ForCausalLM": _HfExamplesInfo("tencent/Hy3-preview", trust_remote_code=True),
"HyperCLOVAXForCausalLM": _HfExamplesInfo( "HyperCLOVAXForCausalLM": _HfExamplesInfo(
"naver-hyperclovax/HyperCLOVAX-SEED-Think-14B", "naver-hyperclovax/HyperCLOVAX-SEED-Think-14B",
trust_remote_code=True, trust_remote_code=True,
...@@ -1516,6 +1517,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1516,6 +1517,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False, is_available_online=False,
min_transformers_version="5.1.0", min_transformers_version="5.1.0",
), ),
"HYV3MTPModel": _HfExamplesInfo(
"tencent/Hy3-preview",
speculative_model="tencent/Hy3-preview",
),
"LongCatFlashMTPModel": _HfExamplesInfo( "LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat", "meituan-longcat/LongCat-Flash-Chat",
trust_remote_code=True, trust_remote_code=True,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.tokenizers import get_tokenizer
# Name passed to ReasoningParserManager.get_reasoning_parser() in the tests below.
parser_name = "hy_v3"
# Model identifier handed to get_tokenizer() by the tokenizer fixture.
MODEL = "tencent/Hy3-preview"
@pytest.fixture(scope="module")
def hy_v3_tokenizer():
    """Provide the tokenizer for ``MODEL``; pytest builds it once per module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
# Expected parser behaviour for each scenario. Every dict carries the raw model
# "output", the expected "reasoning" / "content" split and the expected
# "is_reasoning_end" flag; "reasoning_effort", when present, is forwarded to the
# parser through chat_template_kwargs by test_reasoning.
# The *_STREAM variants hold the same expectations as separate dict objects.

# Reasoning followed by </think> and then regular content.
WITH_THINK = {
    "output": "This is a reasoning section</think>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
    "reasoning_effort": "high",
}
WITH_THINK_STREAM = dict(WITH_THINK)

# Thinking disabled through reasoning_effort="no_think": everything is content.
WITHOUT_THINK = {
    "output": "This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
    "reasoning_effort": "no_think",
}
WITHOUT_THINK_STREAM = dict(WITHOUT_THINK)

# No reasoning_effort key at all: the parser is built without extra kwargs.
WITH_REASONING_EFFORT_NONE = {
    "output": "This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}
WITH_REASONING_EFFORT_NONE_STREAM = dict(WITH_REASONING_EFFORT_NONE)

# Output ends right after </think>, so there is no content.
COMPLETE_REASONING = {
    "output": "This is a reasoning section</think>",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
    "reasoning_effort": "high",
}

# Newlines inside both the reasoning and the content parts.
MULTILINE_REASONING = {
    "output": "This is a reasoning\nsection</think>This is the rest\nThat",
    "reasoning": "This is a reasoning\nsection",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
    "reasoning_effort": "high",
}

# Reasoning that never reaches </think>.
ONLY_OPEN_TAG = {
    "output": "This is a reasoning section",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
    "reasoning_effort": "high",
}
ONLY_OPEN_TAG_STREAM = dict(ONLY_OPEN_TAG)
# (streaming, expectation dict) pairs consumed by test_reasoning; each scenario
# is listed once for non-streaming (False) and once for streaming (True).
TEST_CASES = [
    pytest.param(False, WITH_THINK, id="with_think"),
    pytest.param(True, WITH_THINK_STREAM, id="with_think_stream"),
    pytest.param(False, WITHOUT_THINK, id="without_think"),
    pytest.param(True, WITHOUT_THINK_STREAM, id="without_think_stream"),
    pytest.param(
        False, WITH_REASONING_EFFORT_NONE, id="with_reasoning_effort_none"
    ),
    pytest.param(
        True,
        WITH_REASONING_EFFORT_NONE_STREAM,
        id="with_reasoning_effort_none_stream",
    ),
    pytest.param(False, COMPLETE_REASONING, id="complete_reasoning"),
    pytest.param(True, COMPLETE_REASONING, id="complete_reasoning_stream"),
    pytest.param(False, MULTILINE_REASONING, id="multiline_reasoning"),
    pytest.param(True, MULTILINE_REASONING, id="multiline_reasoning_stream"),
    pytest.param(False, ONLY_OPEN_TAG, id="only_open_tag"),
    pytest.param(True, ONLY_OPEN_TAG_STREAM, id="only_open_tag_stream"),
]
# Full-prompt fixtures for test_is_reasoning_end_full_prompt.
# Single turn whose assistant reply stops inside an unclosed <think> section.
STILL_REASONING_PROMPT = """<|hy_begin▁of▁sentence|>
You are a helpful assistant.
<|reasoning_mode|>reasoning_effort:high<|hy_User|>
What is the capital of France?<|hy_Assistant|>
<think>The user is asking for the capital of"""
# Single turn whose assistant reply has closed </think> and produced an answer.
DONE_REASONING_PROMPT = """<|hy_begin▁of▁sentence|>
You are a helpful assistant.
<|reasoning_mode|>reasoning_effort:high<|hy_User|>
What is the capital of France?<|hy_Assistant|>
<think>The user is asking for the capital of France.</think>
The capital of France is Paris."""
# Two turns: the first reply contains an empty <think></think> pair, the second
# reply is still inside an unclosed <think> section.
# NOTE(review): the first assistant marker is split across a line break
# ("<|hy_Assistant|" / ">") — confirm this is intentional.
MULTI_TURN_STILL_REASONING_PROMPT = """<|hy_begin▁of▁sentence|>
You are a helpful assistant.
<|reasoning_mode|>reasoning_effort:high<|hy_User|>
What is the capital of France?<|hy_Assistant|
><think></think>The capital of France is Paris.<eos:6124c78e>
<|hy_User|>What about Chile?<|hy_Assistant|>
<think>The user is asking for the capital of"""
# Two turns where the second reply has closed </think> and produced an answer.
# NOTE(review): same split "<|hy_Assistant|" / ">" marker as above — confirm.
MULTI_TURN_DONE_REASONING_PROMPT = """<|hy_begin▁of▁sentence|>
You are a helpful assistant.
<|reasoning_mode|>reasoning_effort:high<|hy_User|>
What is the capital of France?<|hy_Assistant|
><think></think>The capital of France is Paris.<eos:6124c78e>
<|hy_User|>What about Chile?<|hy_Assistant|>
<think>The user is asking for the capital of Chile.</think>
The capital of Chile is Santiago."""
# (prompt, expected is_reasoning_end) pairs for test_is_reasoning_end_full_prompt.
REASONING_END_TEST_CASES = [
    pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"),
    pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"),
    pytest.param(
        MULTI_TURN_STILL_REASONING_PROMPT,
        False,
        id="multi_turn_still_reasoning",
    ),
    pytest.param(
        MULTI_TURN_DONE_REASONING_PROMPT,
        True,
        id="multi_turn_done_reasoning",
    ),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    hy_v3_tokenizer,
):
    """Check the reasoning/content split and the reasoning-end flag for one case.

    The expected output text is tokenized, each token is decoded back to its own
    string, and those pieces are fed to the parser via
    ``run_reasoning_extraction`` in streaming or non-streaming mode.
    """
    tokens = hy_v3_tokenizer.tokenize(param_dict["output"])
    token_texts: list[str] = []
    for token in tokens:
        token_texts.append(hy_v3_tokenizer.convert_tokens_to_string([token]))

    # Only cases that define "reasoning_effort" pass chat_template_kwargs.
    if "reasoning_effort" in param_dict:
        extra_kwargs = {
            "chat_template_kwargs": {
                "reasoning_effort": param_dict["reasoning_effort"]
            }
        }
    else:
        extra_kwargs = {}

    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(hy_v3_tokenizer, **extra_kwargs)

    reasoning, content = run_reasoning_extraction(
        parser, token_texts, streaming=streaming
    )
    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    token_ids = hy_v3_tokenizer.convert_tokens_to_ids(tokens)
    reasoning_ended = parser.is_reasoning_end(token_ids)
    assert reasoning_ended == param_dict["is_reasoning_end"]
@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES)
def test_is_reasoning_end_full_prompt(
    prompt: str, is_reasoning_end: bool, hy_v3_tokenizer
):
    """Check ``is_reasoning_end`` on the token ids of a complete chat prompt."""
    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(
        hy_v3_tokenizer,
        chat_template_kwargs={"reasoning_effort": "high"},
    )
    prompt_ids = hy_v3_tokenizer.convert_tokens_to_ids(
        hy_v3_tokenizer.tokenize(prompt)
    )
    assert parser.is_reasoning_end(prompt_ids) == is_reasoning_end
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Tests for the HYV3 tool call parser."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
FunctionDefinition,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.hy_v3_tool_parser import HYV3ToolParser
# Parser name for HY V3.
# NOTE(review): not referenced by the tests visible in this module — confirm it is needed.
parser_name = "hy_v3"
# Model identifier handed to get_tokenizer() by the tokenizer fixture.
MODEL = "tencent/Hy3-preview"
@pytest.fixture(scope="module")
def hy_v3_tokenizer():
    """Provide the tokenizer for ``MODEL``; pytest builds it once per module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def hy_v3_tool_parser(hy_v3_tokenizer):
    """Provide a new ``HYV3ToolParser`` for every test."""
    tool_parser = HYV3ToolParser(hy_v3_tokenizer)
    return tool_parser
@pytest.fixture
def mock_request() -> ChatCompletionRequest:
    """Provide a mocked chat request exposing two tools and ``tool_choice="auto"``.

    The tools are ``get_current_date`` (empty parameter schema) and
    ``get_weather`` (string ``city`` and ``date`` properties).
    """
    date_tool = ChatCompletionToolsParam(
        function=FunctionDefinition(name="get_current_date", parameters={}),
    )
    weather_schema = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "date": {"type": "string"},
        },
    }
    weather_tool = ChatCompletionToolsParam(
        function=FunctionDefinition(
            name="get_weather",
            parameters=weather_schema,
        ),
    )

    request = Mock(spec=ChatCompletionRequest)
    request.tools = [date_tool, weather_tool]
    request.tool_choice = "auto"
    return request
class TestHYV3ExtractToolCalls:
    """Non-streaming checks for ``HYV3ToolParser.extract_tool_calls``."""

    def test_no_tool_call(self, hy_v3_tool_parser, mock_request):
        """Plain text reports no tool call and comes back as content."""
        model_output = "This is a plain response."
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert not extracted.tools_called
        assert extracted.content == model_output

    def test_zero_arg_inline(self, hy_v3_tool_parser, mock_request):
        """A call without arguments, written on a single line."""
        model_output = (
            "<tool_calls><tool_call>get_current_date<tool_sep></tool_call></tool_calls>"
        )
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert extracted.tools_called
        first_call = extracted.tool_calls[0].function
        assert first_call.name == "get_current_date"
        assert json.loads(first_call.arguments) == {}
        assert extracted.content is None

    def test_zero_arg_newline(self, hy_v3_tool_parser, mock_request):
        """A call without arguments, with newlines between the tags."""
        model_output = "<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert extracted.tools_called
        assert extracted.tool_calls[0].function.name == "get_current_date"

    def test_args_same_line(self, hy_v3_tool_parser, mock_request):
        """Key/value arguments written without any newlines."""
        model_output = (
            "<tool_calls><tool_call>get_weather<tool_sep><arg_key>city</arg_key><arg_value>Beijing"
            "</arg_value><arg_key>date</arg_key><arg_value>2026-03-30</arg_value></tool_call></tool_calls>"
        )
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert extracted.tools_called
        parsed_args = json.loads(extracted.tool_calls[0].function.arguments)
        assert parsed_args == {
            "city": "Beijing",
            "date": "2026-03-30",
        }

    def test_args_with_newlines(self, hy_v3_tool_parser, mock_request):
        """Key/value arguments with newlines between the tags."""
        model_output = (
            "<tool_calls>\n<tool_call>get_weather<tool_sep>\n<arg_key>city</arg_key>\n<arg_value>Beijing"
            "</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2026-03-30</arg_value>\n</tool_call>\n</tool_calls>"
        )
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert extracted.tools_called
        parsed_args = json.loads(extracted.tool_calls[0].function.arguments)
        assert parsed_args == {
            "city": "Beijing",
            "date": "2026-03-30",
        }

    def test_content_before(self, hy_v3_tool_parser, mock_request):
        """Text preceding the tool-call section is returned as content."""
        model_output = "Checking.<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert extracted.tools_called
        assert extracted.content == "Checking."

    def test_multiple(self, hy_v3_tool_parser, mock_request):
        """Two calls inside one tool-call section are both extracted."""
        model_output = (
            "<tool_calls>\n<tool_call>get_weather<tool_sep>\n<arg_key>city</arg_key>\n<arg_value>Beijing"
            "</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2026-03-30</arg_value>\n</tool_call>\n"
            "<tool_call>get_weather<tool_sep>\n<arg_key>city</arg_key>\n<arg_value>Hangzhou</arg_value>\n"
            "<arg_key>date</arg_key>\n<arg_value>2026-03-30</arg_value>\n</tool_call>\n</tool_calls>"
        )
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert len(extracted.tool_calls) == 2

    def test_empty_content_none(self, hy_v3_tool_parser, mock_request):
        """With nothing before the tool-call section, content is ``None``."""
        model_output = "<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
        extracted = hy_v3_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert extracted.content is None
def _simulate_streaming(
parser: HYV3ToolParser,
deltas: list[str],
request: ChatCompletionRequest,
) -> list[DeltaMessage | None]:
results: list[DeltaMessage | None] = []
previous_text = ""
previous_token_ids: list[int] = []
vocab = parser.vocab
for delta_text in deltas:
current_text = previous_text + delta_text
delta_token_ids = [tid for tok, tid in vocab.items() if tok in delta_text]
current_token_ids = previous_token_ids + delta_token_ids
result = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request,
)
results.append(result)
previous_text = current_text
previous_token_ids = current_token_ids
return results
def _collect_streaming_tool_calls(results: list[DeltaMessage | None]) -> list[dict]:
tool_calls: dict[int, dict] = {}
for result in results:
if result is None or not result.tool_calls:
continue
for tc in result.tool_calls:
idx = tc.index
if idx not in tool_calls:
tool_calls[idx] = {
"name": tc.function.name or "",
"arguments": tc.function.arguments or "",
}
else:
if tc.function.name:
tool_calls[idx]["name"] += tc.function.name
if tc.function.arguments:
tool_calls[idx]["arguments"] += tc.function.arguments
return [tool_calls[i] for i in sorted(tool_calls.keys())]
def _collect_streaming_content(results: list[DeltaMessage | None]) -> str:
parts = []
for result in results:
if result is not None and result.content:
parts.append(result.content)
return "".join(parts)
class TestHYV3ExtractToolCallsStreaming:
    """Streaming checks driven through ``_simulate_streaming``."""

    def test_no_tool_call_streaming(self, hy_v3_tool_parser, mock_request):
        """Plain text chunks are streamed back as content only."""
        chunks = ["This is ", "a plain ", "response."]
        streamed = _simulate_streaming(hy_v3_tool_parser, chunks, mock_request)
        assert _collect_streaming_content(streamed) == "This is a plain response."
        assert len(_collect_streaming_tool_calls(streamed)) == 0

    def test_zero_arg_streaming(self, hy_v3_tool_parser, mock_request):
        """A call without arguments yields one call with ``{}`` arguments."""
        chunks = [
            "<tool_calls>",
            "\n<tool_call>",
            "get_current_date",
            "<tool_sep>",
            "\n</tool_call>",
            "\n</tool_calls>",
        ]
        streamed = _simulate_streaming(hy_v3_tool_parser, chunks, mock_request)
        calls = _collect_streaming_tool_calls(streamed)
        assert len(calls) == 1
        assert calls[0]["name"] == "get_current_date"
        assert json.loads(calls[0]["arguments"]) == {}

    def test_args_streaming(self, hy_v3_tool_parser, mock_request):
        """Key/value pairs arriving chunk by chunk build the argument JSON."""
        chunks = [
            "<tool_calls>",
            "\n<tool_call>",
            "get_weather",
            "<tool_sep>",
            "\n<arg_key>city</arg_key>",
            "\n<arg_value>Beijing</arg_value>",
            "\n<arg_key>date</arg_key>",
            "\n<arg_value>2026-03-30</arg_value>",
            "\n</tool_call>",
            "\n</tool_calls>",
        ]
        streamed = _simulate_streaming(hy_v3_tool_parser, chunks, mock_request)
        calls = _collect_streaming_tool_calls(streamed)
        assert len(calls) == 1 and calls[0]["name"] == "get_weather"
        assert json.loads(calls[0]["arguments"]) == {
            "city": "Beijing",
            "date": "2026-03-30",
        }

    def test_content_before_streaming(self, hy_v3_tool_parser, mock_request):
        """Text ahead of the tool-call section is still emitted as content."""
        chunks = [
            "Checking.",
            "<tool_calls>",
            "\n<tool_call>",
            "get_current_date",
            "<tool_sep>",
            "\n</tool_call>",
            "\n</tool_calls>",
        ]
        streamed = _simulate_streaming(hy_v3_tool_parser, chunks, mock_request)
        assert "Checking." in _collect_streaming_content(streamed)
        calls = _collect_streaming_tool_calls(streamed)
        assert len(calls) == 1 and calls[0]["name"] == "get_current_date"

    def test_multiple_streaming(self, hy_v3_tool_parser, mock_request):
        """Two consecutive calls are reported under separate indices."""
        chunks = [
            "<tool_calls>",
            "\n<tool_call>",
            "get_weather",
            "<tool_sep>",
            "\n<arg_key>city</arg_key>",
            "\n<arg_value>Beijing</arg_value>",
            "\n<arg_key>date</arg_key>",
            "\n<arg_value>2026-03-30</arg_value>",
            "\n</tool_call>",
            "\n<tool_call>",
            "get_weather",
            "<tool_sep>",
            "\n<arg_key>city</arg_key>",
            "\n<arg_value>Hangzhou</arg_value>",
            "\n<arg_key>date</arg_key>",
            "\n<arg_value>2026-03-30</arg_value>",
            "\n</tool_call>",
            "\n</tool_calls>",
        ]
        streamed = _simulate_streaming(hy_v3_tool_parser, chunks, mock_request)
        calls = _collect_streaming_tool_calls(streamed)
        assert len(calls) == 2
        assert json.loads(calls[0]["arguments"])["city"] == "Beijing"
        assert json.loads(calls[1]["arguments"])["city"] == "Hangzhou"

    def test_all_in_one_delta_streaming(self, hy_v3_tool_parser, mock_request):
        """The whole tool-call section delivered as a single delta."""
        whole_output = "<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
        streamed = _simulate_streaming(
            hy_v3_tool_parser, [whole_output], mock_request
        )
        calls = _collect_streaming_tool_calls(streamed)
        assert len(calls) == 1 and calls[0]["name"] == "get_current_date"
        assert json.loads(calls[0]["arguments"]) == {}
...@@ -47,6 +47,7 @@ MTPModelTypes = Literal[ ...@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"mtp", "mtp",
"pangu_ultra_moe_mtp", "pangu_ultra_moe_mtp",
"step3p5_mtp", "step3p5_mtp",
"hy_v3_mtp",
] ]
NgramGPUTypes = Literal["ngram_gpu"] NgramGPUTypes = Literal["ngram_gpu"]
DFlashModelTypes = Literal["dflash"] DFlashModelTypes = Literal["dflash"]
...@@ -364,6 +365,13 @@ class SpeculativeConfig: ...@@ -364,6 +365,13 @@ class SpeculativeConfig:
if initial_architecture == "MistralLarge3ForCausalLM": if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
if hf_config.model_type == "hy_v3":
hf_config.model_type = "hy_v3_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["HYV3MTPModel"]}
)
return hf_config return hf_config
def __post_init__(self): def __post_init__(self):
......
...@@ -1562,6 +1562,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: ...@@ -1562,6 +1562,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
# NemotronH format: .mixer.{k,v}_proj.{k,v}_scale -> # NemotronH format: .mixer.{k,v}_proj.{k,v}_scale ->
# .mixer.attn.{k,v}_scale # .mixer.attn.{k,v}_scale
(r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"), (r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"),
# HYV3 format: .self_attn.q.scale -> .self_attn.attn.q_scale
(r"\.self_attn\.q\.scale$", r".self_attn.attn.q_scale"),
# HYV3 format: .self_attn.{k,v}_cache.scale ->
# .self_attn.attn.{k,v}_scale
(r"\.self_attn\.([kv])_cache\.scale$", r".self_attn.attn.\1_scale"),
# Default format: .{k,v}_scale -> .attn.{k,v}_scale # Default format: .{k,v}_scale -> .attn.{k,v}_scale
(r"\.([qkv])_scale$", r".attn.\1_scale"), (r"\.([qkv])_scale$", r".attn.\1_scale"),
(r"\.([qkv])_zero_point$", r".attn.\1_zero_point"), (r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
...@@ -1576,6 +1581,9 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: ...@@ -1576,6 +1581,9 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
".k_zero_point", ".k_zero_point",
".v_zero_point", ".v_zero_point",
".q_zero_point", ".q_zero_point",
".q.scale",
".k_cache.scale",
".v_cache.scale",
) )
): ):
import regex as re import regex as re
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# coding=utf-8
# Copyright 2026 The HY team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HY V3 MTP model compatible with HuggingFace weights."""
from collections.abc import Iterable
import regex as re
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from .hy_v3 import HYV3DecoderLayer, get_spec_layer_idx_from_weight_name
from .utils import is_pp_missing_parameter, maybe_prefix
def _is_moe(config: PretrainedConfig) -> bool:
return bool(
getattr(config, "num_experts", None)
and (
(isinstance(config.num_experts, int) and config.num_experts > 1)
or (isinstance(config.num_experts, list) and max(config.num_experts) > 1)
)
)
def _get_cla_factor(config: PretrainedConfig) -> int:
if not getattr(config, "use_cla", False):
return 1
return getattr(config, "cla_share_factor", 1)
class HYV3SharedHead(nn.Module):
    """Holder for the ``ParallelLMHead`` attached to an MTP layer.

    ``forward`` is an identity mapping; the ``head`` attribute is what callers
    use when computing logits.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        vocab_size = config.vocab_size
        hidden_size = config.hidden_size
        self.head = ParallelLMHead(
            vocab_size, hidden_size, quant_config=quant_config
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` unchanged."""
        return hidden_states
class HYV3MultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction (MTP) layer.

    Token embeddings and the previous hidden states are each RMS-normalised,
    concatenated on the last dimension, projected back to ``hidden_size`` by
    ``eh_proj``, passed through one ``HYV3DecoderLayer`` and finally
    normalised by ``final_layernorm``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.enorm = RMSNorm(hidden_size, eps=norm_eps)
        self.hnorm = RMSNorm(hidden_size, eps=norm_eps)
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.shared_head = HYV3SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = HYV3DecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )
        # Final layernorm applied after transformer block, before logits
        # projection (matches HF HYV3MTPDecoderLayer.final_layernorm)
        self.final_layernorm = RMSNorm(hidden_size, eps=norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer and return its normalised hidden states.

        ``inputs_embeds`` is required and is modified in place: entries where
        ``positions == 0`` are zeroed. ``input_ids`` and ``spec_step_index``
        are accepted but not read here.
        """
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        at_first_position = positions == 0
        inputs_embeds[at_first_position] = 0

        normed_embeds = self.enorm(inputs_embeds)
        normed_previous = self.hnorm(previous_hidden_states)
        fused = torch.cat([normed_embeds, normed_previous], dim=-1)
        hidden_states = self.eh_proj(fused)

        # HYV3DecoderLayer returns (hidden_states, residual)
        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        return self.final_layernorm(residual + hidden_states)
class HYV3MultiTokenPredictor(nn.Module):
    """Stack of MTP layers plus the token embedding and logits processor.

    Layers are stored in a ``ModuleDict`` keyed by the stringified layer
    index, starting at ``config.num_hidden_layers`` and spanning
    ``config.num_nextn_predict_layers`` entries.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # to map the exact layer index from weights
        first_idx = self.mtp_start_layer_idx
        end_idx = first_idx + self.num_mtp_layers
        mtp_layers = {}
        for idx in range(first_idx, end_idx):
            mtp_layers[str(idx)] = HYV3MultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx`` (modulo layer count).

        ``input_ids`` are embedded with ``embed_tokens`` when ``inputs_embeds``
        is not supplied.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        step = spec_step_idx % self.num_mtp_layers
        layer_key = str(self.mtp_start_layer_idx + step)
        mtp_layer = self.layers[layer_key]
        return mtp_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            step,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits using the shared head of the selected MTP layer."""
        step = spec_step_idx % self.num_mtp_layers
        mtp_layer = self.layers[str(self.mtp_start_layer_idx + step)]
        shared_head = mtp_layer.shared_head
        return self.logits_processor(shared_head.head, shared_head(hidden_states))
class HYV3MTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the MTP predictor and sampler from ``vllm_config``.

    The HF config and quantization config are kept on the instance because
    weight loading reads them.
    """
    super().__init__()
    self.config = vllm_config.model_config.hf_config
    self.quant_config = vllm_config.quant_config
    predictor_prefix = maybe_prefix(prefix, "model")
    self.model = HYV3MultiTokenPredictor(
        vllm_config=vllm_config, prefix=predictor_prefix
    )
    self.sampler = Sampler()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    spec_step_idx: int = 0,
) -> torch.Tensor:
    """Delegate to the wrapped predictor and return its hidden states.

    ``intermediate_tensors`` is accepted but not read here.
    """
    return self.model(
        input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
    )
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    spec_step_idx: int = 0,
) -> torch.Tensor | None:
    """Return logits from the wrapped predictor for the given speculative step."""
    logits = self.model.compute_logits(hidden_states, spec_step_idx)
    return logits
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> SamplerOutput | None:
    """Run the instance's sampler on ``logits`` and return its output."""
    return self.sampler(logits, sampling_metadata)
def _split_qkv_weight(self, qkv: torch.Tensor):
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(
self.config, "num_key_value_heads", self.config.num_attention_heads
)
num_key_value_groups = num_attention_heads // num_kv_heads
hidden_size = self.config.hidden_size
if hasattr(self.config, "head_dim"):
attention_head_dim = self.config.head_dim
elif hasattr(self.config, "attention_head_dim"):
attention_head_dim = self.config.attention_head_dim
else:
attention_head_dim = self.config.hidden_size // num_attention_heads
qkv = qkv.reshape(
num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
)
q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
q = q.reshape(-1, hidden_size)
k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size)
return torch.concat((q, k, v))
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load MTP weights from an iterable of ``(name, tensor)`` pairs.

    Per weight, in order:
      1. ``model.embed_tokens.weight`` and ``lm_head.weight`` are copied into
         the MTP embedding and into the shared head of every MTP layer.
      2. Rotary caches are skipped and ``*_proj_bias`` names are normalised.
      3. Cache scales reported by the quantization config are loaded.
      4. Weights that do not belong to a speculative layer are skipped; the
         rest are renamed by ``_rewrite_spec_layer_name``.
      5. The weight is dispatched to the stacked (q/k/v, gate/up), fused-split
         (``gate_and_up_proj``, fused ``qkv_proj``), expert, or plain loader.
    """
    cla_factor = _get_cla_factor(self.config)
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    num_attention_heads = self.config.num_attention_heads
    num_kv_heads = getattr(
        self.config, "num_key_value_heads", self.config.num_attention_heads
    )
    # (param_name, weight_name, denominator, [(shard_id, units)], transform)
    split_params_mapping = [
        (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
        (
            ".qkv_proj",
            ".qkv_proj",
            num_attention_heads + num_kv_heads * 2,
            [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
            self._split_qkv_weight,
        ),
    ]
    if _is_moe(self.config):
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )
    else:
        expert_params_mapping = []
    params_dict = dict(self.named_parameters())
    # V3 shared weights mapping:
    # - embed_tokens: from main model's model.embed_tokens.weight
    # - lm_head: from main model's lm_head.weight → MTP shared_head.head
    #   (HF infer_mtp uses head_weight=self.lm_head.weight, not the
    #   checkpoint's model.layers.<N>.shared_head.weight)
    # - No norm mapping (V3 MTP has no intermediate norm before lm_head)
    # The checkpoint's own shared_head weights are skipped below, so lm_head
    # must reach the shared head of every MTP layer, not only the first one.
    v3_shared_weights = {
        "model.embed_tokens.weight": ["model.embed_tokens.weight"],
        "lm_head.weight": [
            f"model.layers.{layer_key}.shared_head.head.weight"
            for layer_key in self.model.layers.keys()
        ],
    }
    for name, loaded_weight in weights:
        # Intercept shared weights before any other processing
        if name in v3_shared_weights:
            for target_name in v3_shared_weights[name]:
                if target_name not in params_dict:
                    continue
                param = params_dict[target_name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
            continue
        if "rotary_emb.inv_freq" in name:
            continue
        if "gate_proj_bias" in name:
            name = name.replace("gate_proj_bias", "gate_proj.bias")
        if "up_proj_bias" in name:
            name = name.replace("up_proj_bias", "up_proj.bias")
        if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
            continue
        if self.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        if self.quant_config is not None and (
            scale_name := self.quant_config.get_cache_scale(name)
        ):
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            continue
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is None:
            continue
        name = self._rewrite_spec_layer_name(spec_layer, name)
        # Skip weights that _rewrite_spec_layer_name marked for skipping
        if name == "__skip__":
            continue
        if "scale" in name:
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
        is_found = False
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            if "mlp.experts" in name:
                continue
            if weight_name == ".q_proj":
                match = re.search(r"layers\.\d+", name)
                if match:
                    layer_id = int(match.group(0).split(".")[-1])
                    if cla_factor > 1 and layer_id % cla_factor != 0:
                        continue
            name = name.replace(weight_name, param_name)
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            is_found = True
            break
        if is_found:
            continue
        for param_name, weight_name, den, split_param, func in split_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            assert loaded_weight.shape[0] % den == 0
            units = loaded_weight.shape[0] // den
            param = params_dict[name]
            weight_loader = param.weight_loader
            offset = 0
            for shard_id, num in split_param:
                new_offset = offset + num * units
                if func:
                    weight_loader(
                        param, func(loaded_weight)[offset:new_offset], shard_id
                    )
                else:
                    weight_loader(param, loaded_weight[offset:new_offset], shard_id)
                offset = new_offset
            break
        else:
            # Reached only when no fused-split mapping consumed the weight.
            if name.endswith(".bias") and name not in params_dict:
                continue
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Plain parameter: no stacked, split or expert mapping matched.
                if is_pp_missing_parameter(name, self):
                    continue
                if "mlp.gate.wg." in name:
                    name = name.replace("wg.", "")
                # V3 checkpoint: mlp.router.gate -> mlp.gate
                if "mlp.router.gate." in name:
                    name = name.replace("router.gate.", "gate.")
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
    """Map a checkpoint weight name of a spec (MTP) layer to its vLLM name.

    Args:
        spec_layer: Index of the speculative layer inside ``model.layers``.
        name: Weight name as it appears in the checkpoint.

    Returns:
        ``"__skip__"`` for weights that must not be loaded, the unchanged
        name for the spec-layer-specific weights, and otherwise the name
        with ``mtp_block.`` inserted right after the layer prefix.
    """
    layer_prefix = f"model.layers.{spec_layer}."

    # embed_tokens is absent under the spec layer in the V3 MTP checkpoint,
    # and shared_head is replaced by the main model's lm_head.
    for dropped in ("embed_tokens", "shared_head"):
        if layer_prefix + dropped in name:
            return "__skip__"

    # Weights owned by the spec layer itself keep their checkpoint name.
    own_weight_markers = ("enorm", "hnorm", "eh_proj", "final_layernorm")
    if any(marker in name for marker in own_weight_markers):
        return name

    # Everything else is a transformer-block weight and lives under
    # `.mtp_block`.
    return name.replace(layer_prefix, layer_prefix + "mtp_block.")
...@@ -133,6 +133,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -133,6 +133,7 @@ _TEXT_GENERATION_MODELS = {
"Grok1ForCausalLM": ("grok1", "GrokForCausalLM"), "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
"HYV3ForCausalLM": ("hy_v3", "HYV3ForCausalLM"),
"HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
"HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"), "HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"),
"HyperCLOVAXForCausalLM": ("hyperclovax", "HyperCLOVAXForCausalLM"), "HyperCLOVAXForCausalLM": ("hyperclovax", "HyperCLOVAXForCausalLM"),
...@@ -599,6 +600,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -599,6 +600,7 @@ _SPECULATIVE_DECODING_MODELS = {
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"), "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
"Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"), "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
"Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"), "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
"HYV3MTPModel": ("hy_v3_mtp", "HYV3MTP"),
# Temporarily disabled. # Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
......
...@@ -56,6 +56,10 @@ _REASONING_PARSERS_TO_REGISTER = { ...@@ -56,6 +56,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"hunyuan_a13b_reasoning_parser", "hunyuan_a13b_reasoning_parser",
"HunyuanA13BReasoningParser", "HunyuanA13BReasoningParser",
), ),
"hy_v3": (
"hy_v3_reasoning_parser",
"HYV3ReasoningParser",
),
"kimi_k2": ( "kimi_k2": (
"kimi_k2_reasoning_parser", "kimi_k2_reasoning_parser",
"KimiK2ReasoningParser", "KimiK2ReasoningParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class HYV3ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for HYV3 models.

    The HYV3 model uses <think>...</think> tokens to denote reasoning text.
    Depending on the `reasoning_effort` entry of `chat_template_kwargs`, this
    parser either extracts the reasoning content itself (through
    `BaseThinkingReasoningParser`) or, when the effort is "no_think",
    forwards every call to an `IdentityReasoningParser`.
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """Pick the parsing mode from `chat_template_kwargs`.

        Args:
            tokenizer: Tokenizer forwarded to the base parser (and to the
                identity parser when one is created).
            *args: Extra positional arguments forwarded unchanged.
            **kwargs: Extra keyword arguments. `chat_template_kwargs` is
                consumed here; its `reasoning_effort` entry selects the mode
                and defaults to "no_think" when missing.
        """
        super().__init__(tokenizer, *args, **kwargs)

        # Only `chat_template_kwargs["reasoning_effort"]` is consulted; when
        # it is absent the parser falls back to "no_think".
        # NOTE(review): presumably the caller folds a request-level
        # `reasoning_effort` into `chat_template_kwargs` beforehand - confirm.
        chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
        # Read with .get() rather than .pop(): this dict is owned by the
        # caller, and removing the key would silently change it for whoever
        # uses it afterwards.
        reasoning_effort = chat_kwargs.get("reasoning_effort", "no_think")
        logger.debug("reasoning_effort for choosing parser: %s", reasoning_effort)

        self._identity_parser: IdentityReasoningParser | None
        if reasoning_effort == "no_think":
            self._identity_parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
        else:
            self._identity_parser = None

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<think>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "</think>"

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Delegate to the identity parser in "no_think" mode, else the base."""
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end(input_ids)
        return super().is_reasoning_end(input_ids)

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Delegate to the identity parser in "no_think" mode, else the base."""
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )
        return super().is_reasoning_end_streaming(input_ids, delta_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Delegate to the identity parser in "no_think" mode, else the base."""
        if self._identity_parser is not None:
            return self._identity_parser.extract_content_ids(input_ids)
        return super().extract_content_ids(input_ids)

    def extract_reasoning(
        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> tuple[str | None, str | None]:
        """Delegate to the identity parser in "no_think" mode, else the base."""
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning(model_output, request)
        return super().extract_reasoning(model_output, request)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Split one streamed delta into reasoning and content.

        In "no_think" mode the call is forwarded to the identity parser.
        Otherwise the base implementation runs first; its non-None result is
        replaced when the start token has not been seen in either the previous
        or the delta token ids, in which case the delta is classified purely
        by the position of the end token.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning_streaming(
                previous_text,
                current_text,
                delta_text,
                previous_token_ids,
                current_token_ids,
                delta_token_ids,
            )

        ret = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        # No start token was generated so far.
        # NOTE(review): presumably the prompt already opened the <think>
        # block, so the output starts inside the reasoning - confirm.
        if (
            ret is not None
            and self.start_token_id not in previous_token_ids
            and self.start_token_id not in delta_token_ids
        ):
            if self.end_token_id in delta_token_ids:
                # end token in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token) :]
                return DeltaMessage(
                    reasoning=reasoning,
                    content=content if content else None,
                )
            elif self.end_token_id in previous_token_ids:
                # end token in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
                # no end token in previous or delta, reasoning content continues
                return DeltaMessage(reasoning=delta_text)
        return ret
...@@ -66,6 +66,10 @@ _TOOL_PARSERS_TO_REGISTER = { ...@@ -66,6 +66,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"hunyuan_a13b_tool_parser", "hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser", "HunyuanA13BToolParser",
), ),
"hy_v3": (
"hy_v3_tool_parser",
"HYV3ToolParser",
),
"internlm": ( "internlm": (
"internlm2_tool_parser", "internlm2_tool_parser",
"Internlm2ToolParser", "Internlm2ToolParser",
......
This diff is collapsed.
...@@ -94,6 +94,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -94,6 +94,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
funaudiochat="FunAudioChatConfig", funaudiochat="FunAudioChatConfig",
granite4_vision="Granite4VisionConfig", granite4_vision="Granite4VisionConfig",
hunyuan_vl="HunYuanVLConfig", hunyuan_vl="HunYuanVLConfig",
hy_v3="HYV3Config",
isaac="IsaacConfig", isaac="IsaacConfig",
kimi_k2="DeepseekV3Config", # Kimi K2 uses same architecture as DeepSeek V3 kimi_k2="DeepseekV3Config", # Kimi K2 uses same architecture as DeepSeek V3
kimi_linear="KimiLinearConfig", kimi_linear="KimiLinearConfig",
......
...@@ -36,6 +36,7 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -36,6 +36,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HYV3Config": "vllm.transformers_utils.configs.hy_v3",
"HyperCLOVAXConfig": "vllm.transformers_utils.configs.hyperclovax", "HyperCLOVAXConfig": "vllm.transformers_utils.configs.hyperclovax",
"IsaacConfig": "vllm.transformers_utils.configs.isaac", "IsaacConfig": "vllm.transformers_utils.configs.isaac",
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
...@@ -97,6 +98,7 @@ __all__ = [ ...@@ -97,6 +98,7 @@ __all__ = [
"HunYuanVLConfig", "HunYuanVLConfig",
"HunYuanVLTextConfig", "HunYuanVLTextConfig",
"HunYuanVLVisionConfig", "HunYuanVLVisionConfig",
"HYV3Config",
"HyperCLOVAXConfig", "HyperCLOVAXConfig",
"IsaacConfig", "IsaacConfig",
"RWConfig", "RWConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers.configuration_utils import PretrainedConfig
class HYV3Config(PretrainedConfig):
    r"""
    Configuration container for the HY V3 MoE language model.

    Instantiating the class with no arguments yields the default HYV3
    architecture described below. The class derives from
    [`PretrainedConfig`]; see its documentation for the options shared by all
    configurations.

    Args:
        vocab_size (`int`, *optional*, defaults to 120832):
            Size of the vocabulary.
        hidden_size (`int`, *optional*, defaults to 4096):
            Width of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 13312):
            Width of the intermediate representation in dense FFN layers.
        num_hidden_layers (`int`, *optional*, defaults to 80):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads per attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            Number of key/value heads used by grouped-query attention.
        head_dim (`int`, *optional*, defaults to 128):
            Size of a single attention head.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            Activation function of the FFN layers.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            Longest sequence length the model supports.
        initializer_range (`float`, *optional*, defaults to 0.006):
            Standard deviation of the truncated normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon of the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the KV cache is used while decoding.
        pad_token_id (`int`, *optional*):
            Id of the padding token.
        bos_token_id (`int`, *optional*):
            Id of the beginning-of-sequence token.
        eos_token_id (`int` or `List[int]`, *optional*):
            Id(s) of the end-of-sequence token. A single `int` is stored as a
            one-element list.
        rope_parameters (`dict`, *optional*):
            RoPE parameters. When omitted, a default-type RoPE is built from
            the `rope_theta` keyword argument (11158840.0 if not given).
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether RMSNorm is applied to query and key states before
            attention.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output embedding weights are shared.
        enable_attention_fp32_softmax (`bool`, *optional*, defaults to `False`):
            Whether the attention softmax is upcast to float32. Note: the
            eager attention path always computes softmax in float32 regardless
            of this setting; the flag is reserved for future use with custom
            attention backends.
        enable_lm_head_fp32 (`bool`, *optional*, defaults to `True`):
            Whether the LM head computation is upcast to float32.
        num_experts (`int`, *optional*, defaults to 192):
            Total number of MoE experts.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of experts routed to per token (top-k routing).
        num_shared_experts (`int`, *optional*, defaults to 1):
            Number of always-active shared experts, combined into one MLP.
        expert_hidden_dim (`int`, *optional*, defaults to 1536):
            Intermediate width of each individual MoE expert.
        moe_router_enable_expert_bias (`bool`, *optional*, defaults to `True`):
            Whether the router uses a per-expert load-balancing bias.
        moe_router_use_sigmoid (`bool`, *optional*, defaults to `True`):
            Whether router scores come from a sigmoid instead of a softmax.
        route_norm (`bool`, *optional*, defaults to `True`):
            Whether routing scores are normalized under sigmoid routing.
        router_scaling_factor (`float`, *optional*):
            Optional multiplicative factor applied to routing scores.
        use_grouped_mm (`bool`, *optional*, defaults to `False`):
            Whether grouped GEMM is used for expert computation (not yet
            implemented).
        enable_moe_fp32_combine (`bool`, *optional*, defaults to `False`):
            Whether expert outputs are accumulated in float32.
        first_k_dense_replace (`int`, *optional*, defaults to 1):
            Number of leading decoder layers that use a dense FFN instead of
            MoE.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether each MoE layer returns its router logits. Useful for the
            auxiliary load-balancing loss during training; off by default to
            avoid keeping per-layer router tensors in memory at inference.

    Example:

    ```python
    >>> from vllm.transformers_utils.configs.hy_v3 import HYV3Config
    >>> config = HYV3Config()
    >>> config.model_type
    'hy_v3'
    ```
    """

    model_type = "hy_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=120832,
        hidden_size=4096,
        intermediate_size=13312,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.006,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        rope_parameters: dict[str, Any] | None = None,
        qk_norm=True,
        tie_word_embeddings=False,
        enable_attention_fp32_softmax=False,
        enable_lm_head_fp32=True,
        # MoE specific
        num_experts=192,
        num_experts_per_tok=8,
        num_shared_experts=1,
        expert_hidden_dim=1536,
        moe_router_enable_expert_bias=True,
        moe_router_use_sigmoid=True,
        route_norm=True,
        router_scaling_factor=None,
        use_grouped_mm=False,
        enable_moe_fp32_combine=False,
        # Dense/MoE layer control
        first_k_dense_replace=1,
        output_router_logits=False,
        **kwargs,
    ):
        # Core transformer geometry
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # `rope_theta` is always taken out of kwargs, even when explicit
        # `rope_parameters` make it unused, so it never reaches the base class.
        rope_theta = kwargs.pop("rope_theta", 11158840.0)
        self.rope_parameters = (
            {"rope_type": "default", "rope_theta": rope_theta}
            if rope_parameters is None
            else rope_parameters
        )

        # Attention / head options
        self.qk_norm = qk_norm
        self.tie_word_embeddings = tie_word_embeddings
        self.enable_attention_fp32_softmax = enable_attention_fp32_softmax
        self.enable_lm_head_fp32 = enable_lm_head_fp32

        # Mixture-of-experts options
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.num_shared_experts = num_shared_experts
        self.expert_hidden_dim = expert_hidden_dim
        self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
        self.moe_router_use_sigmoid = moe_router_use_sigmoid
        self.route_norm = route_norm
        self.router_scaling_factor = router_scaling_factor
        self.use_grouped_mm = use_grouped_mm
        self.enable_moe_fp32_combine = enable_moe_fp32_combine

        # Dense vs. MoE layer layout
        self.first_k_dense_replace = first_k_dense_replace
        self.output_router_logits = output_router_logits

        # A lone integer id is wrapped so the base class always gets a list
        # (None is not an int, so it passes through untouched).
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment