Commit 643c0d59 authored by luopl's avatar luopl
Browse files

解决test2的qfeat传输问题

parent 9e61b53d
#!/usr/bin/env python3
"""Minimal classify demo using token IDs as input.
This mirrors the docs example:
llm = LLM(model="...", runner="pooling")
(output,) = llm.classify("Hello, my name is")
but feeds DEFAULT_PROMPT_TOKEN_IDS via token_inputs instead of text.
"""
from vllm import LLM
from vllm.inputs import token_inputs
from transformers import AutoTokenizer
DEFAULT_PROMPT_TOKEN_IDS = [
[127958, 58, 10172, 24575, 8437, 7489, 51, 60, 220, 57668, 102832, 80073, 75761, 102245, 39045, 57668, 105982, 103429, 88852, 9743, 34208, 2929, 3922, 101423, 83125, 110357, 107759, 82317, 101505, 101009, 1811, 15225, 61633, 3922, 101992, 80073, 120702, 17, 15, 17, 20, 8107, 15, 23, 9953, 17, 22, 9080, 3490, 2929, 5232, 86461, 102160, 36827, 31867, 19, 34208, 86461, 102160, 36827, 31867, 21, 107938, 105528, 198, 12, 11615, 101241, 5232, 111642, 198, 12, 11615, 104780, 101526, 105344, 5232, 100377, 104780, 198, 12, 11615, 106444, 101526, 105344, 5232, 101055, 106444, 271, 9743, 29411, 12, 52561, 229, 34972, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 62, 59505, 112203, 32239, 198, 12, 73028, 96, 17161, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 198, 20, 9953, 16, 24, 104223, 3922, 86461, 102160, 106518, 101612, 100733, 17039, 33671, 24186, 101016, 22656, 36827, 31867, 21, 101129, 102984, 91547, 1811, 101291, 57, 115718, 100532, 100880, 101612, 108616, 101016, 105164, 3922, 36827, 31867, 21, 10447, 37197, 32218, 112561, 117408, 35287, 42246, 111379, 100992, 103735, 1811, 107743, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 120928, 50991, 107758, 31867, 21, 1322, 102072, 36827, 31867, 21, 112848, 104946, 111379, 102735, 105541, 3922, 92672, 109034, 119927, 3922, 101739, 101911, 100532, 5486, 9080, 102874, 102581, 100518, 39135, 113577, 1811, 102317, 57237, 16, 13, 22, 24, 6358, 100547, 95337, 17, 13, 17, 7501, 10110, 36827, 31867, 21, 49372, 46281, 32943, 28037, 48915, 101026, 101058, 123721, 1811, 108895, 101035, 3922, 36827, 31867, 21, 101129, 100503, 72843, 17, 15, 4, 100477, 105516, 3922, 110349, 107371, 24946, 101971, 36, 100646, 105653, 3922, 123798, 20, 24, 24, 24, 24186, 72718, 6447, 198, 86461, 102160, 36827, 31867, 21, 5232, 116967, 109845, 101037, 10447, 94, 105, 72237, 27327, 76537, 198, 86461, 102160, 36827, 31867, 21, 113577, 105868, 58805, 94278, 238, 100667, 15592, 220, 22, 473, 220, 17, 21, 15, 107311, 3922, 23, 101021, 16, 21, 120312, 71600, 3922, 19, 20211, 102271, 102812, 103486, 1811, 103954, 9688, 19085, 10860, 55, 220, 20, 15, 21, 15, 120928, 50991, 3922, 16, 16, 20, 54, 100433, 118068, 71600, 3922, 101046, 16931, 1242, 220, 19, 100344, 1811, 105874, 17905, 3922, 107743, 16, 21, 108465, 17, 13, 20, 42, 220, 16, 21, 20, 11732, 109943, 53434, 108018, 3922, 101046, 38, 6354, 72501, 100344, 3922, 43292, 103897, 104196, 120822, 82317, 102698, 58322, 3922, 19, 15, 271, 9743, 91547, 58521, 29411, 482, 220, 58521, 31091, 5232, 964, 16, 21, 23, 198, 482, 220, 58521, 101241, 5232, 111642, 11, 76771, 239, 83301, 11, 47850, 233, 33748, 198, 482, 220, 58521, 105302, 5232, 109173, 271, 22452, 91547, 101143, 5232, 112203, 32239, 198, 482, 75677, 111, 55038, 101241, 5232, 104312, 11, 4996, 223, 98, 100563, 11, 220, 101766, 198, 482, 41766, 229, 81742, 33005, 5232, 100359, 198, 482, 75677, 111, 55038, 105344, 5232, 23, 198, 482, 61696, 225, 101028, 105344, 5232, 17, 271, 9743, 91547, 21082, 5232, 17, 15, 17, 20, 8107, 15, 20, 9953, 16, 24, 9080, 198, 482, 9085, 115, 251, 104944, 9039, 5232, 24, 24, 271, 15225, 67117, 83125, 110357, 107759, 5232, 127962, 127972, 127973, 127974, 127975, 127967],
[127958, 58, 10172, 24575, 8437, 7489, 51, 60, 220, 57668, 102832, 80073, 75761, 102245, 39045, 57668, 105982, 103429, 88852, 9743, 34208, 2929, 3922, 101423, 83125, 110357, 107759, 82317, 101505, 101009, 1811, 15225, 61633, 3922, 101992, 80073, 120702, 17, 15, 17, 20, 8107, 15, 23, 9953, 17, 22, 9080, 3490, 2929, 5232, 86461, 102160, 36827, 31867, 19, 34208, 86461, 102160, 36827, 31867, 21, 107938, 105528, 198, 12, 11615, 101241, 5232, 111642, 198, 12, 11615, 104780, 101526, 105344, 5232, 100377, 104780, 198, 12, 11615, 106444, 101526, 105344, 5232, 101055, 106444, 271, 9743, 29411, 12, 52561, 229, 34972, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 62, 59505, 112203, 32239, 198, 12, 73028, 96, 17161, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 198, 20, 9953, 16, 24, 104223, 3922, 86461, 102160, 106518, 101612, 100733, 17039, 33671, 24186, 101016, 22656, 36827, 31867, 21, 101129, 102984, 91547, 1811, 101291, 57, 115718, 100532, 100880, 101612, 108616, 101016, 105164, 3922, 36827, 31867, 21, 10447, 37197, 32218, 112561, 117408, 35287, 42246, 111379, 100992, 103735, 1811, 107743, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 120928, 50991, 107758, 31867, 21, 1322, 102072, 36827, 31867, 21, 112848, 104946, 111379, 102735, 105541, 3922, 92672, 109034, 119927, 3922, 101739, 101911, 100532, 5486, 9080, 102874, 102581, 100518, 39135, 113577, 1811, 102317, 57237, 16, 13, 22, 24, 6358, 100547, 95337, 17, 13, 17, 7501, 10110, 36827, 31867, 21, 49372, 46281, 32943, 28037, 48915, 101026, 101058, 123721, 1811, 108895, 101035, 3922, 36827, 31867, 21, 101129, 100503, 72843, 17, 15, 4, 100477, 105516, 3922, 110349, 107371, 24946, 101971, 36, 100646, 105653, 3922, 123798, 20, 24, 24, 24, 24186, 72718, 6447, 198, 86461, 102160, 36827, 31867, 21, 5232, 116967, 109845, 101037, 10447, 94, 105, 72237, 27327, 76537, 198, 86461, 102160, 36827, 31867, 21, 113577, 105868, 58805, 94278, 238, 100667, 15592, 220, 22, 473, 220, 17, 21, 15, 107311, 3922, 23, 101021, 16, 21, 120312, 71600, 3922, 19, 20211, 102271, 102812, 103486, 1811, 103954, 9688, 19085, 10860, 55, 220, 20, 15, 21, 15, 120928, 50991, 3922, 16, 16, 20, 54, 100433, 118068, 71600, 3922, 101046, 16931, 1242, 220, 19, 100344, 1811, 105874, 17905, 3922, 107743, 16, 21, 108465, 17, 13, 20, 42, 220, 16, 21, 20, 11732, 109943, 53434, 108018, 3922, 101046, 38, 6354, 72501, 100344, 3922, 43292, 103897, 104196, 120822, 82317, 102698, 58322, 3922, 19, 15, 271, 9743, 91547, 58521, 29411, 482, 220, 58521, 31091, 5232, 964, 16, 21, 23, 198, 482, 220, 58521, 101241, 5232, 111642, 11, 76771, 239, 83301, 11, 47850, 233, 33748, 198, 482, 220, 58521, 105302, 5232, 109173, 271, 22452, 91547, 101143, 5232, 112203, 32239, 198, 482, 75677, 111, 55038, 101241, 5232, 104312, 11, 4996, 223, 98, 100563, 11, 220, 101766, 198, 482, 41766, 229, 81742, 33005, 5232, 100359, 198, 482, 75677, 111, 55038, 105344, 5232, 23, 198, 482, 61696, 225, 101028, 105344, 5232, 17, 271, 9743, 91547, 21082, 5232, 17, 15, 17, 20, 8107, 15, 20, 9953, 16, 24, 9080, 198, 482, 9085, 115, 251, 104944, 9039, 5232, 24, 24, 271, 15225, 67117, 83125, 110357, 107759, 5232, 127962, 127972, 127973, 127974, 127975, 127967]
]
PROMPTS = "[GenRM-vCoT] 你是一个搜索排序专家,请你仔细阅读以下Doc和Query,给出文章满意度评分及具体原因。请注意,本次搜索时间是2025年08月27日。\n\nQuery:华硕天选4和华硕天选6性价比对比\n- Query领域:电子产品\n- Query时效需求等级:低时效\n- Query权威需求等级:弱权威\n\nDoc:\n- 标题:华硕天选6系列发布 搭载满功耗RTX 5060实际到手5999元起_手机新浪网\n- 正文:<pcut>华硕天选6系列发布 搭载满功耗RTX 5060实际到手5999元起\n5月19日晚,华硕旗下潮玩新次元游戏本天选6系列正式发布。作为Z世代青年的潮酷游戏装备,天选6 系列再一次印证了其出色的综合实力。搭载满功耗RTX 5060笔记本电脑GPU的天选6 Pro以及天选6皆拥有出色的性能释放,同时颜值出众,魔幻青、日蚀灰双色可选。薄至1.79cm轻约2.2kg(天选6),从内到外实现全面进阶。首发期间,天选6系列均享20%国家补贴,叠加晒单返E卡福利,到手5999元起!\n华硕天选6:超高选购价值 硬核能打\n华硕天选6可选全新AMD 锐龙 AI 7 H 260处理器,8核心16线程设计,4nm先进工艺打造。配备GeForce RTX 5060笔记本电脑GPU,115W满功耗设计,支持DLSS 4技术。屏幕上,搭载16英寸2.5K 165Hz电竞级面板,支持G-SYNC技术,无惧画面撕裂及拖影,40<pcut>\n\nDoc发布作者:\n - 作者名称:IT168\n - 作者领域:电子产品, 科技, 手机\n - 作者认证:未知\n\n Doc发布平台:新浪网\n - 平台领域:财经, 健康, 旅游\n - 备案类型:企业\n - 平台等级:8\n - 权威等级:2\n\nDoc发布时间:2025年05月19日\n - 距今天数:99\n\n请输出文章满意度评分:"
MODELPATH = "/home/luopl/hunyuan_tx/test_2"
def test_prompt(llm):
tokenizer = AutoTokenizer.from_pretrained(MODELPATH, trust_remote_code=True)
input_ids = tokenizer(PROMPTS, return_tensors="pt", trust_remote_code=True)["input_ids"]
outputs = llm.classify(token_inputs(input_ids[0], qfeat=[2, 0, 20]))
for i, out in enumerate(outputs):
probs = out.outputs.probs
print(f"Request {i}, class probs = {probs}")
def test_tokenid(llm):
for ids in DEFAULT_PROMPT_TOKEN_IDS:
outputs = llm.classify(token_inputs(ids, qfeat=[2, 0, 20]))
for i, out in enumerate(outputs):
probs = out.outputs.probs
print(f"Request {i}, class probs = {probs}")
if __name__ == "__main__":
llm = LLM(model=MODELPATH, task="classify",
trust_remote_code=True,
enforce_eager=True,
enable_chunked_prefill=False)
test_prompt(llm)
test_tokenid(llm)
# print(input_ids)
\ No newline at end of file
...@@ -10,7 +10,7 @@ import torch.nn as nn ...@@ -10,7 +10,7 @@ import torch.nn as nn
from pydantic import ValidationError from pydantic import ValidationError
from tqdm.auto import tqdm from tqdm.auto import tqdm
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.inputs import token_inputs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, BeamSearchSequence,
create_sort_beams_key_function) create_sort_beams_key_function)
...@@ -89,7 +89,7 @@ class LLM: ...@@ -89,7 +89,7 @@ class LLM:
or videos from directories specified by the server file system. or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted This is a security risk. Should only be enabled in trusted
environments. environments.
allowed_media_domains: If set, only media URLs that belong to this allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs. domain can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism. execution with tensor parallelism.
...@@ -984,6 +984,9 @@ class LLM: ...@@ -984,6 +984,9 @@ class LLM:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
if prompts["qfeat"] is not None:
pooling_params.qfeat = prompts["qfeat"]
for param in as_iter(pooling_params): for param in as_iter(pooling_params):
param.verify(pooling_task, model_config) param.verify(pooling_task, model_config)
# for backwards compatibility # for backwards compatibility
......
...@@ -120,15 +120,15 @@ Note that "singleton" is as opposed to a data structure ...@@ -120,15 +120,15 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder the user desires to express both the encoder & decoder
prompts explicitly, i.e. prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating (3) as a member of a larger data structure encapsulating
more than one prompt, i.e. more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
""" """
...@@ -220,16 +220,21 @@ def token_inputs( ...@@ -220,16 +220,21 @@ def token_inputs(
prompt_token_ids: list[int], prompt_token_ids: list[int],
prompt: Optional[str] = None, prompt: Optional[str] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
qfeat: Optional[list] = None,
) -> TokenInputs: ) -> TokenInputs:
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values.""" values."""
# print("************* {} ***************".format(qfeat))
if isinstance(prompt_token_ids, torch.Tensor):
prompt_token_ids = prompt_token_ids.tolist()
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None: if prompt is not None:
inputs["prompt"] = prompt inputs["prompt"] = prompt
if cache_salt is not None: if cache_salt is not None:
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
if qfeat is not None:
inputs["qfeat"] = qfeat
return inputs return inputs
......
...@@ -314,7 +314,6 @@ class InputPreprocessor: ...@@ -314,7 +314,6 @@ class InputPreprocessor:
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs( prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs) parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs] inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"): if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal( inputs = self._process_multimodal(
...@@ -325,7 +324,8 @@ class InputPreprocessor: ...@@ -325,7 +324,8 @@ class InputPreprocessor:
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
else: else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids) inputs = token_inputs(prompt_token_ids=prompt_token_ids,
qfeat=parsed_content["qfeat"])
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt inputs["cache_salt"] = cache_salt
...@@ -358,6 +358,7 @@ class InputPreprocessor: ...@@ -358,6 +358,7 @@ class InputPreprocessor:
inputs = token_inputs( inputs = token_inputs(
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
qfeat=parsed_content["qfeat"]
) )
if cache_salt := parsed_content.get("cache_salt"): if cache_salt := parsed_content.get("cache_salt"):
......
...@@ -347,7 +347,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T: ...@@ -347,7 +347,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
classifier=self._classifier, classifier=self._classifier,
act_fn=PoolerIdentity(), act_fn=PoolerIdentity(),
) )
}) })
def _classifier(self, x: torch.Tensor, pooling_metadata: PoolingMetadata = None): def _classifier(self, x: torch.Tensor, pooling_metadata: PoolingMetadata = None):
...@@ -377,6 +377,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T: ...@@ -377,6 +377,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
# seq_length].squeeze(-1) # seq_length].squeeze(-1)
return reward return reward
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -423,7 +424,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -423,7 +424,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
if is_pooling_model(cls): if is_pooling_model(cls):
return cls return cls
# Lazy import # Lazy importGPUModelRunner
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler, from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler, DispatchPooler, Pooler,
...@@ -452,6 +453,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -452,6 +453,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "pool_head"), prefix=maybe_prefix(prefix, "pool_head"),
return_bias=False, return_bias=False,
) )
self.pool_head2 = ReplicatedLinear( self.pool_head2 = ReplicatedLinear(
config.hidden_size, config.hidden_size,
config.class_num, config.class_num,
...@@ -461,7 +463,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -461,7 +463,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "pool_head2"), prefix=maybe_prefix(prefix, "pool_head2"),
return_bias=False, return_bias=False,
) )
self.qfeat_emb =ReplicatedLinear(
self.qfeat_emb = ReplicatedLinear(
2, 2,
128, 128,
bias=True, bias=True,
...@@ -470,14 +473,15 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -470,14 +473,15 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "qfeat_emb"), prefix=maybe_prefix(prefix, "qfeat_emb"),
return_bias=False, return_bias=False,
) )
self.qfeat_emb_topic = VocabParallelEmbedding( self.qfeat_emb_topic = VocabParallelEmbedding(
100, 100,
128, 128,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qfeat_emb_topic", prefix=f"{prefix}.qfeat_emb_topic",
) )
self.qfeat_fc1 =ReplicatedLinear(
self.qfeat_fc1 = ReplicatedLinear(
256, 256,
256, 256,
bias=True, bias=True,
...@@ -486,8 +490,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -486,8 +490,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "qfeat_fc1"), prefix=maybe_prefix(prefix, "qfeat_fc1"),
return_bias=False, return_bias=False,
) )
self.qfeat_fc2 =ReplicatedLinear( self.qfeat_fc2 = ReplicatedLinear(
256, 256,
3, 3,
bias=True, bias=True,
...@@ -496,8 +500,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -496,8 +500,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "qfeat_fc2"), prefix=maybe_prefix(prefix, "qfeat_fc2"),
return_bias=False, return_bias=False,
) )
self.qfeat_fc3 =ReplicatedLinear( self.qfeat_fc3 = ReplicatedLinear(
256, 256,
3, 3,
bias=True, bias=True,
...@@ -530,17 +534,20 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -530,17 +534,20 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
classifier=self._classifier, classifier=self._classifier,
act_fn=PoolerIdentity(), act_fn=PoolerIdentity(),
) )
}) })
def encode_qfeat(self, qfeat): def encode_qfeat(self, qfeat):
emb1 = self.qfeat_emb(qfeat[:,:2]) emb1 = self.qfeat_emb(qfeat[:,:2])
emb2 = self.qfeat_emb_topic(qfeat[:,2].to(torch.long)) emb2 = self.qfeat_emb_topic(qfeat[:,2].to(torch.long))
hidden = torch.cat([emb1, emb2], dim=1) hidden = torch.cat([emb1, emb2], dim=1)
hidden = self.qfeat_fc1(hidden) hidden = self.qfeat_fc1(hidden)
hidden = torch.relu(hidden) hidden = torch.relu(hidden)
# hidden = torch.softmax(hidden, dim=1) # hidden = torch.softmax(hidden, dim=1)
return hidden return hidden
def _classifier(self, x: torch.Tensor,pooling_metadata: PoolingMetadata=None):
def _classifier(self, x: torch.Tensor, pooling_metadata: PoolingMetadata=None):
pooled_output= self.pool_head(x) pooled_output= self.pool_head(x)
if isinstance(pooled_output, tuple): if isinstance(pooled_output, tuple):
pooled_output = pooled_output[0] pooled_output = pooled_output[0]
...@@ -549,50 +556,38 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -549,50 +556,38 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
pooled_output_rel = self.pool_head2(pooled_output) # bs * class_num pooled_output_rel = self.pool_head2(pooled_output) # bs * class_num
pooled_output_time = self.pool_head2(pooled_output) # bs * class_num pooled_output_time = self.pool_head2(pooled_output) # bs * class_num
pooled_output_auth = self.pool_head2(pooled_output) # bs * class_num pooled_output_auth = self.pool_head2(pooled_output) # bs * class_num
qfeat=torch.tensor([[2, 0, 20]],device=pooled_output.device) # print(f"**************sat_logits:{sat_logits}********")
# qfeat = torch.tensor([[2, 0, 20]], device=pooled_output.device)
# print("*** adapters_classify.py L558 pooling_metadata.pooling_params", pooling_metadata.pooling_params[0].qfeat)
#PoolingParams(task=classify, normalize=None, dimensions=None, activation=True, softmax=None, step_tag_id=None, returned_token_ids=None, requires_token_ids=False, qfeat=None, extra_kwargs=None)
qfeat = pooling_metadata.pooling_params[0].qfeat
qfeat = torch.tensor([qfeat], device=pooled_output.device) if qfeat is not None else torch.tensor([[1, 2, 3]], device=pooled_output.device)
qfeat = qfeat.to(dtype=pooled_output.dtype) qfeat = qfeat.to(dtype=pooled_output.dtype)
print("*** qfeat ***", qfeat)
qhidden = self.encode_qfeat(qfeat) qhidden = self.encode_qfeat(qfeat)
a_wei = self.qfeat_fc2(qhidden) a_wei = self.qfeat_fc2(qhidden)
a_bias = self.qfeat_fc3(qhidden) a_bias = self.qfeat_fc3(qhidden)
if pooled_output.size()[1]<3: if pooled_output.size()[1] < 3:
batch_size = pooled_output.size(0) # 或 pooled_output.shape[0] batch_size = pooled_output.size(0) # 或 pooled_output.shape[0]
reward = torch.full((batch_size, 1), float('inf'), device=pooled_output.device, dtype=pooled_output.dtype) reward = torch.full((batch_size, 1), float('inf'), device=pooled_output.device, dtype=pooled_output.dtype)
return reward return reward
last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu-1 last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu - 1
batch_indices = torch.arange(last_token_idx.size(0)) batch_indices = torch.arange(last_token_idx.size(0))
sat_logits = pooled_output_sat[batch_indices, last_token_idx,:] sat_logits = pooled_output_sat[batch_indices, last_token_idx-1,:]
auth_logits = pooled_output_auth[batch_indices, last_token_idx-1,:] # print(f"**************sat_logits:{sat_logits}********")
time_logits = pooled_output_time[batch_indices, last_token_idx-2,:] auth_logits = pooled_output_auth[batch_indices, last_token_idx-2,:]
rel_logits = pooled_output_rel[batch_indices, last_token_idx-3,:] # print(f"**************auth_logits:{auth_logits}********")
time_logits = pooled_output_time[batch_indices, last_token_idx-3,:]
# print(f"**************time_logits:{time_logits}********")
rel_logits = pooled_output_rel[batch_indices, last_token_idx-4,:]
# print(f"**************rel_logits:{rel_logits}********")
multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1) multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True) task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
task_logits = torch.sigmoid(task_logits) task_logits = torch.sigmoid(task_logits)
sat_logits_new = task_logits * sat_logits sat_logits_new = task_logits * sat_logits
logits = sat_logits_new logits = sat_logits_new
reward = logits reward = logits
# print(reward)
# sat_logits = pooled_output_sat[torch.arange(batch_size, device=pooled_output.device), seq_length-1]
# auth_logits = pooled_output_auth[torch.arange(batch_size, device=pooled_output.device), seq_length-2]
# time_logits = pooled_output_time[torch.arange(batch_size, device=pooled_output.device), seq_length-3]
# rel_logits = pooled_output_rel[torch.arange(batch_size, device=pooled_output.device), seq_length-4]
# # a_score = torch.sigmoid(torch.concat([rel_logits, time_logits, auth_logits], dim=1))
# multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
# task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
# task_logits = torch.sigmoid(task_logits)
# #gate_time = (a_wei * multii_logits + wei_time).sum(dim=1, keepdim=True)
# #gate_time = torch.sigmoid(gate_time)
# #gate_auth = (a_wei * multii_logits + wei_auth).sum(dim=1, keepdim=True)
# #gate_auth = torch.sigmoid(gate_auth)
# sat_logits_new = task_logits * sat_logits
# #logits = 2.0 * sat_logits_new.detach() + 0.25 * (qfeat[:,0].float().unsqueeze(1)) * gate_time * time_logits.detach() + 0.5 * (qfeat[:,1].float().unsqueeze(1) + 0.4) * gate_auth * auth_logits.detach()
# logits = sat_logits_new
# reward = logits.squeeze(-1)
return reward return reward
...@@ -601,12 +596,13 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -601,12 +596,13 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
self.input_ids =input_ids self.input_ids =input_ids
return super().forward(input_ids, positions, intermediate_tensors, return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None) tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None) method = getattr(self.config, "method", None)
......
...@@ -21,7 +21,6 @@ import regex as re ...@@ -21,7 +21,6 @@ import regex as re
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
...@@ -36,9 +35,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -36,9 +35,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, # from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler, # DispatchPooler, Pooler,
PoolingMethod, PoolingType) # PoolingMethod, PoolingType)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
...@@ -52,7 +51,7 @@ from vllm.sequence import IntermediateTensors ...@@ -52,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_layers, maybe_prefix) make_layers, maybe_prefix)
import torch.nn.functional as F
def _is_moe(config: PretrainedConfig) -> bool: def _is_moe(config: PretrainedConfig) -> bool:
num_experts = getattr(config, "num_experts", None) num_experts = getattr(config, "num_experts", None)
...@@ -124,7 +123,7 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -124,7 +123,7 @@ class HunYuanSparseMoeBlock(nn.Module):
): ):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
...@@ -134,7 +133,7 @@ class HunYuanSparseMoeBlock(nn.Module): ...@@ -134,7 +133,7 @@ class HunYuanSparseMoeBlock(nn.Module):
raise ValueError( raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than " f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.") f"the number of experts {config.num_experts}.")
# Get layer_id topk if config.moe_topk is a list # Get layer_id topk if config.moe_topk is a list
if isinstance(config.moe_topk, list): if isinstance(config.moe_topk, list):
assert layer_id >= 0 assert layer_id >= 0
...@@ -238,7 +237,7 @@ class HunYuanAttention(nn.Module): ...@@ -238,7 +237,7 @@ class HunYuanAttention(nn.Module):
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "", prefix: str = "",
layer_id: int = -1, layer_id: int = -1
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -327,17 +326,40 @@ class HunYuanAttention(nn.Module): ...@@ -327,17 +326,40 @@ class HunYuanAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
ori_k = k ori_k = k
ori_v = v
if self.use_qk_norm: if self.use_qk_norm:
q = self.query_layernorm( q = self.query_layernorm(
q.view(-1, self.num_heads, self.head_dim).contiguous()) q.view(-1, self.num_heads, self.head_dim).contiguous())
k = self.key_layernorm( k = self.key_layernorm(
k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
# Transpose.
attn_output = self.attn(q, k, v)
# For o_proj q = q.view(-1, self.num_heads, self.head_dim)
attn_output = attn_output.view(q.shape[0], -1) # Expand the key and value to handle GQA.
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
q = q.unsqueeze(0).transpose(1, 2).contiguous()
k = k.unsqueeze(0).transpose(1, 2).contiguous()
v = v.unsqueeze(0).transpose(1, 2).contiguous()
# print("*** hunyuan.py L354 q.shape[0]", q.shape[0])
attn_mask = torch.ones(q.shape[0], 1, q.shape[2], q.shape[2]).cuda().contiguous()
# print("====",q.shape, k.shape, v.shape, attn_mask.shape)
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask,
dropout_p=0.0
)
attn_output = attn_output.transpose(1, 2).flatten(-2, -1).squeeze()
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output, (ori_k, v) return output, (ori_k, ori_v)
class HunYuanCrossAttention(nn.Module): class HunYuanCrossAttention(nn.Module):
...@@ -439,6 +461,7 @@ class HunYuanCrossAttention(nn.Module): ...@@ -439,6 +461,7 @@ class HunYuanCrossAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_states: Optional[Tuple[torch.Tensor]] = None, kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
attention_mask=[]
assert kv_states is not None assert kv_states is not None
ori_k, v = kv_states # use last layer kv, ori_k, v = kv_states # use last layer kv,
k = ori_k k = ori_k
...@@ -450,8 +473,8 @@ class HunYuanCrossAttention(nn.Module): ...@@ -450,8 +473,8 @@ class HunYuanCrossAttention(nn.Module):
q.view(-1, self.num_heads, self.head_dim).contiguous()) q.view(-1, self.num_heads, self.head_dim).contiguous())
k = self.key_layernorm( k = self.key_layernorm(
k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
attn_output = self.attn(q, k, v) attn_output = self.attn(hidden_states, attention_mask, positions,kv_states)
# For o_proj # For o_proj
attn_output = attn_output.view(q.shape[0], -1) attn_output = attn_output.view(q.shape[0], -1)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
...@@ -467,14 +490,14 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -467,14 +490,14 @@ class HunYuanDecoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
layer_id: int = -1, layer_id: int = -1,
enable_eplb: bool = False, enable_eplb: bool = False
) -> None: ) -> None:
super().__init__() super().__init__()
assert layer_id >= 0 assert layer_id >= 0
self.layer_id = layer_id self.layer_id = layer_id
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.intermediate_size = (config.intermediate_size if isinstance( self.intermediate_size = (config.intermediate_size if isinstance(
config.intermediate_size, int) else config.intermediate_size, int) else
config.intermediate_size[layer_id]) config.intermediate_size[layer_id])
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
...@@ -490,7 +513,7 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -490,7 +513,7 @@ class HunYuanDecoderLayer(nn.Module):
attention_type = (AttentionType.ENCODER_DECODER attention_type = (AttentionType.ENCODER_DECODER
if layer_id >= 0 and layer_id % cla_factor != 0 else if layer_id >= 0 and layer_id % cla_factor != 0 else
AttentionType.DECODER) AttentionType.DECODER)
if attention_type == AttentionType.DECODER: if attention_type == AttentionType.DECODER:
self.self_attn = HunYuanAttention( self.self_attn = HunYuanAttention(
config=config, config=config,
...@@ -505,7 +528,7 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -505,7 +528,7 @@ class HunYuanDecoderLayer(nn.Module):
bias=attention_bias, bias=attention_bias,
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
layer_id=layer_id, layer_id=layer_id
) )
elif attention_type == AttentionType.ENCODER_DECODER: elif attention_type == AttentionType.ENCODER_DECODER:
self.self_attn = HunYuanCrossAttention( self.self_attn = HunYuanCrossAttention(
...@@ -526,7 +549,7 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -526,7 +549,7 @@ class HunYuanDecoderLayer(nn.Module):
else: else:
raise RuntimeError(f"Unsupported attention type: {attention_type}") raise RuntimeError(f"Unsupported attention type: {attention_type}")
if _is_moe(config): if _is_moe(config):
self.mlp = HunYuanSparseMoeBlock( self.mlp = HunYuanSparseMoeBlock(
config=config, config=config,
...@@ -556,8 +579,8 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -556,8 +579,8 @@ class HunYuanDecoderLayer(nn.Module):
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None, kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
residual=hidden_states residual=hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -566,12 +589,13 @@ class HunYuanDecoderLayer(nn.Module): ...@@ -566,12 +589,13 @@ class HunYuanDecoderLayer(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
kv_states=kv_states, kv_states=kv_states,
) )
hidden_states =residual+hidden_states
residual=hidden_states
hidden_states= self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states
hidden_states=self.mlp(hidden_states) residual = hidden_states
hidden_states=hidden_states+residual
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
return hidden_states, residual, ori_kv_states return hidden_states, residual, ori_kv_states
...@@ -581,7 +605,6 @@ class HunYuanModel(nn.Module): ...@@ -581,7 +605,6 @@ class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -614,16 +637,15 @@ class HunYuanModel(nn.Module): ...@@ -614,16 +637,15 @@ class HunYuanModel(nn.Module):
layer_id=int(prefix.split(".")[-1]), layer_id=int(prefix.split(".")[-1]),
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -644,9 +666,9 @@ class HunYuanModel(nn.Module): ...@@ -644,9 +666,9 @@ class HunYuanModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
cla_factor = _get_cla_factor(self.config) cla_factor = _get_cla_factor(self.config)
prev_kv_states = None prev_kv_states = None
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual, kv_states = layer( hidden_states, residual, kv_states = layer(
...@@ -655,7 +677,7 @@ class HunYuanModel(nn.Module): ...@@ -655,7 +677,7 @@ class HunYuanModel(nn.Module):
residual, residual,
prev_kv_states, prev_kv_states,
) )
if (getattr(self.config, "use_cla", False) if (getattr(self.config, "use_cla", False)
and (i - self.start_layer) % cla_factor == 0): and (i - self.start_layer) % cla_factor == 0):
prev_kv_states = kv_states prev_kv_states = kv_states
...@@ -692,7 +714,7 @@ class HunYuanModel(nn.Module): ...@@ -692,7 +714,7 @@ class HunYuanModel(nn.Module):
k = k.reshape(-1, hidden_size) k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size) v = v.reshape(-1, hidden_size)
return torch.concat((q, k, v)) return torch.concat((q, k, v))
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if _is_moe(self.config): if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
...@@ -706,7 +728,7 @@ class HunYuanModel(nn.Module): ...@@ -706,7 +728,7 @@ class HunYuanModel(nn.Module):
) )
else: else:
return [] return []
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
cla_factor = _get_cla_factor(self.config) cla_factor = _get_cla_factor(self.config)
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -894,11 +916,6 @@ class HunYuanModel(nn.Module): ...@@ -894,11 +916,6 @@ class HunYuanModel(nn.Module):
return loaded_params return loaded_params
class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
...@@ -920,10 +937,10 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -920,10 +937,10 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.pad_id = self.config.pad_id self.pad_id = self.config.pad_id
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
...@@ -979,9 +996,6 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -979,9 +996,6 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -990,13 +1004,13 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -990,13 +1004,13 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors, model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
self, self,
...@@ -1030,4 +1044,4 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -1030,4 +1044,4 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -23,7 +23,7 @@ class PoolingParams( ...@@ -23,7 +23,7 @@ class PoolingParams(
truncate_prompt_tokens: Controls prompt truncation. truncate_prompt_tokens: Controls prompt truncation.
Set to -1 to use the model's default truncation size. Set to -1 to use the model's default truncation size.
Set to k to keep only the last k tokens (left truncation). Set to k to keep only the last k tokens (left truncation).
Set to None to disable truncation. Set to None to disable truncation.
normalize: Whether to normalize the embeddings outputs. normalize: Whether to normalize the embeddings outputs.
dimensions: Reduce the dimensions of embeddings dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation. if model support matryoshka representation.
...@@ -64,6 +64,8 @@ class PoolingParams( ...@@ -64,6 +64,8 @@ class PoolingParams(
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
qfeat: Optional[list] = None
@property @property
def all_parameters(self) -> list[str]: def all_parameters(self) -> list[str]:
return [ return [
...@@ -184,6 +186,7 @@ class PoolingParams( ...@@ -184,6 +186,7 @@ class PoolingParams(
f"step_tag_id={self.step_tag_id}, " f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, " f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, " f"requires_token_ids={self.requires_token_ids}, "
f"qfeat={self.qfeat}, "
f"extra_kwargs={self.extra_kwargs})") f"extra_kwargs={self.extra_kwargs})")
def __post_init__(self) -> None: def __post_init__(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment