Commit 643c0d59 authored by luopl's avatar luopl
Browse files

解决test2的qfeat传输问题

parent 9e61b53d
#!/usr/bin/env python3
"""Minimal classify demo using token IDs as input.
This mirrors the docs example:
llm = LLM(model="...", runner="pooling")
(output,) = llm.classify("Hello, my name is")
but feeds DEFAULT_PROMPT_TOKEN_IDS via token_inputs instead of text.
"""
from vllm import LLM
from vllm.inputs import token_inputs
from transformers import AutoTokenizer
DEFAULT_PROMPT_TOKEN_IDS = [
[127958, 58, 10172, 24575, 8437, 7489, 51, 60, 220, 57668, 102832, 80073, 75761, 102245, 39045, 57668, 105982, 103429, 88852, 9743, 34208, 2929, 3922, 101423, 83125, 110357, 107759, 82317, 101505, 101009, 1811, 15225, 61633, 3922, 101992, 80073, 120702, 17, 15, 17, 20, 8107, 15, 23, 9953, 17, 22, 9080, 3490, 2929, 5232, 86461, 102160, 36827, 31867, 19, 34208, 86461, 102160, 36827, 31867, 21, 107938, 105528, 198, 12, 11615, 101241, 5232, 111642, 198, 12, 11615, 104780, 101526, 105344, 5232, 100377, 104780, 198, 12, 11615, 106444, 101526, 105344, 5232, 101055, 106444, 271, 9743, 29411, 12, 52561, 229, 34972, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 62, 59505, 112203, 32239, 198, 12, 73028, 96, 17161, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 198, 20, 9953, 16, 24, 104223, 3922, 86461, 102160, 106518, 101612, 100733, 17039, 33671, 24186, 101016, 22656, 36827, 31867, 21, 101129, 102984, 91547, 1811, 101291, 57, 115718, 100532, 100880, 101612, 108616, 101016, 105164, 3922, 36827, 31867, 21, 10447, 37197, 32218, 112561, 117408, 35287, 42246, 111379, 100992, 103735, 1811, 107743, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 120928, 50991, 107758, 31867, 21, 1322, 102072, 36827, 31867, 21, 112848, 104946, 111379, 102735, 105541, 3922, 92672, 109034, 119927, 3922, 101739, 101911, 100532, 5486, 9080, 102874, 102581, 100518, 39135, 113577, 1811, 102317, 57237, 16, 13, 22, 24, 6358, 100547, 95337, 17, 13, 17, 7501, 10110, 36827, 31867, 21, 49372, 46281, 32943, 28037, 48915, 101026, 101058, 123721, 1811, 108895, 101035, 3922, 36827, 31867, 21, 101129, 100503, 72843, 17, 15, 4, 100477, 105516, 3922, 110349, 107371, 24946, 101971, 36, 100646, 105653, 3922, 123798, 20, 24, 24, 24, 24186, 72718, 6447, 198, 86461, 102160, 36827, 31867, 21, 5232, 116967, 109845, 101037, 10447, 94, 105, 72237, 27327, 76537, 198, 86461, 102160, 36827, 31867, 21, 113577, 105868, 58805, 94278, 238, 100667, 15592, 220, 22, 473, 220, 17, 21, 15, 107311, 3922, 23, 101021, 16, 21, 120312, 71600, 3922, 19, 20211, 102271, 102812, 103486, 1811, 103954, 9688, 19085, 10860, 55, 220, 20, 15, 21, 15, 120928, 50991, 3922, 16, 16, 20, 54, 100433, 118068, 71600, 3922, 101046, 16931, 1242, 220, 19, 100344, 1811, 105874, 17905, 3922, 107743, 16, 21, 108465, 17, 13, 20, 42, 220, 16, 21, 20, 11732, 109943, 53434, 108018, 3922, 101046, 38, 6354, 72501, 100344, 3922, 43292, 103897, 104196, 120822, 82317, 102698, 58322, 3922, 19, 15, 271, 9743, 91547, 58521, 29411, 482, 220, 58521, 31091, 5232, 964, 16, 21, 23, 198, 482, 220, 58521, 101241, 5232, 111642, 11, 76771, 239, 83301, 11, 47850, 233, 33748, 198, 482, 220, 58521, 105302, 5232, 109173, 271, 22452, 91547, 101143, 5232, 112203, 32239, 198, 482, 75677, 111, 55038, 101241, 5232, 104312, 11, 4996, 223, 98, 100563, 11, 220, 101766, 198, 482, 41766, 229, 81742, 33005, 5232, 100359, 198, 482, 75677, 111, 55038, 105344, 5232, 23, 198, 482, 61696, 225, 101028, 105344, 5232, 17, 271, 9743, 91547, 21082, 5232, 17, 15, 17, 20, 8107, 15, 20, 9953, 16, 24, 9080, 198, 482, 9085, 115, 251, 104944, 9039, 5232, 24, 24, 271, 15225, 67117, 83125, 110357, 107759, 5232, 127962, 127972, 127973, 127974, 127975, 127967],
[127958, 58, 10172, 24575, 8437, 7489, 51, 60, 220, 57668, 102832, 80073, 75761, 102245, 39045, 57668, 105982, 103429, 88852, 9743, 34208, 2929, 3922, 101423, 83125, 110357, 107759, 82317, 101505, 101009, 1811, 15225, 61633, 3922, 101992, 80073, 120702, 17, 15, 17, 20, 8107, 15, 23, 9953, 17, 22, 9080, 3490, 2929, 5232, 86461, 102160, 36827, 31867, 19, 34208, 86461, 102160, 36827, 31867, 21, 107938, 105528, 198, 12, 11615, 101241, 5232, 111642, 198, 12, 11615, 104780, 101526, 105344, 5232, 100377, 104780, 198, 12, 11615, 106444, 101526, 105344, 5232, 101055, 106444, 271, 9743, 29411, 12, 52561, 229, 34972, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 62, 59505, 112203, 32239, 198, 12, 73028, 96, 17161, 5232, 86461, 102160, 36827, 31867, 21, 101129, 91547, 6704, 238, 255, 28466, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 100824, 123798, 20, 24, 24, 24, 24186, 72718, 198, 20, 9953, 16, 24, 104223, 3922, 86461, 102160, 106518, 101612, 100733, 17039, 33671, 24186, 101016, 22656, 36827, 31867, 21, 101129, 102984, 91547, 1811, 101291, 57, 115718, 100532, 100880, 101612, 108616, 101016, 105164, 3922, 36827, 31867, 21, 10447, 37197, 32218, 112561, 117408, 35287, 42246, 111379, 100992, 103735, 1811, 107743, 100433, 118068, 5463, 55, 220, 20, 15, 21, 15, 120928, 50991, 107758, 31867, 21, 1322, 102072, 36827, 31867, 21, 112848, 104946, 111379, 102735, 105541, 3922, 92672, 109034, 119927, 3922, 101739, 101911, 100532, 5486, 9080, 102874, 102581, 100518, 39135, 113577, 1811, 102317, 57237, 16, 13, 22, 24, 6358, 100547, 95337, 17, 13, 17, 7501, 10110, 36827, 31867, 21, 49372, 46281, 32943, 28037, 48915, 101026, 101058, 123721, 1811, 108895, 101035, 3922, 36827, 31867, 21, 101129, 100503, 72843, 17, 15, 4, 100477, 105516, 3922, 110349, 107371, 24946, 101971, 36, 100646, 105653, 3922, 123798, 20, 24, 24, 24, 24186, 72718, 6447, 198, 86461, 102160, 36827, 31867, 21, 5232, 116967, 109845, 101037, 10447, 94, 105, 72237, 27327, 76537, 198, 86461, 102160, 36827, 31867, 21, 113577, 105868, 58805, 94278, 238, 100667, 15592, 220, 22, 473, 220, 17, 21, 15, 107311, 3922, 23, 101021, 16, 21, 120312, 71600, 3922, 19, 20211, 102271, 102812, 103486, 1811, 103954, 9688, 19085, 10860, 55, 220, 20, 15, 21, 15, 120928, 50991, 3922, 16, 16, 20, 54, 100433, 118068, 71600, 3922, 101046, 16931, 1242, 220, 19, 100344, 1811, 105874, 17905, 3922, 107743, 16, 21, 108465, 17, 13, 20, 42, 220, 16, 21, 20, 11732, 109943, 53434, 108018, 3922, 101046, 38, 6354, 72501, 100344, 3922, 43292, 103897, 104196, 120822, 82317, 102698, 58322, 3922, 19, 15, 271, 9743, 91547, 58521, 29411, 482, 220, 58521, 31091, 5232, 964, 16, 21, 23, 198, 482, 220, 58521, 101241, 5232, 111642, 11, 76771, 239, 83301, 11, 47850, 233, 33748, 198, 482, 220, 58521, 105302, 5232, 109173, 271, 22452, 91547, 101143, 5232, 112203, 32239, 198, 482, 75677, 111, 55038, 101241, 5232, 104312, 11, 4996, 223, 98, 100563, 11, 220, 101766, 198, 482, 41766, 229, 81742, 33005, 5232, 100359, 198, 482, 75677, 111, 55038, 105344, 5232, 23, 198, 482, 61696, 225, 101028, 105344, 5232, 17, 271, 9743, 91547, 21082, 5232, 17, 15, 17, 20, 8107, 15, 20, 9953, 16, 24, 9080, 198, 482, 9085, 115, 251, 104944, 9039, 5232, 24, 24, 271, 15225, 67117, 83125, 110357, 107759, 5232, 127962, 127972, 127973, 127974, 127975, 127967]
]
PROMPTS = "[GenRM-vCoT] 你是一个搜索排序专家,请你仔细阅读以下Doc和Query,给出文章满意度评分及具体原因。请注意,本次搜索时间是2025年08月27日。\n\nQuery:华硕天选4和华硕天选6性价比对比\n- Query领域:电子产品\n- Query时效需求等级:低时效\n- Query权威需求等级:弱权威\n\nDoc:\n- 标题:华硕天选6系列发布 搭载满功耗RTX 5060实际到手5999元起_手机新浪网\n- 正文:<pcut>华硕天选6系列发布 搭载满功耗RTX 5060实际到手5999元起\n5月19日晚,华硕旗下潮玩新次元游戏本天选6系列正式发布。作为Z世代青年的潮酷游戏装备,天选6 系列再一次印证了其出色的综合实力。搭载满功耗RTX 5060笔记本电脑GPU的天选6 Pro以及天选6皆拥有出色的性能释放,同时颜值出众,魔幻青、日蚀灰双色可选。薄至1.79cm轻约2.2kg(天选6),从内到外实现全面进阶。首发期间,天选6系列均享20%国家补贴,叠加晒单返E卡福利,到手5999元起!\n华硕天选6:超高选购价值 硬核能打\n华硕天选6可选全新AMD 锐龙 AI 7 H 260处理器,8核心16线程设计,4nm先进工艺打造。配备GeForce RTX 5060笔记本电脑GPU,115W满功耗设计,支持DLSS 4技术。屏幕上,搭载16英寸2.5K 165Hz电竞级面板,支持G-SYNC技术,无惧画面撕裂及拖影,40<pcut>\n\nDoc发布作者:\n - 作者名称:IT168\n - 作者领域:电子产品, 科技, 手机\n - 作者认证:未知\n\n Doc发布平台:新浪网\n - 平台领域:财经, 健康, 旅游\n - 备案类型:企业\n - 平台等级:8\n - 权威等级:2\n\nDoc发布时间:2025年05月19日\n - 距今天数:99\n\n请输出文章满意度评分:"
MODELPATH = "/home/luopl/hunyuan_tx/test_2"
def test_prompt(llm):
tokenizer = AutoTokenizer.from_pretrained(MODELPATH, trust_remote_code=True)
input_ids = tokenizer(PROMPTS, return_tensors="pt", trust_remote_code=True)["input_ids"]
outputs = llm.classify(token_inputs(input_ids[0], qfeat=[2, 0, 20]))
for i, out in enumerate(outputs):
probs = out.outputs.probs
print(f"Request {i}, class probs = {probs}")
def test_tokenid(llm):
for ids in DEFAULT_PROMPT_TOKEN_IDS:
outputs = llm.classify(token_inputs(ids, qfeat=[2, 0, 20]))
for i, out in enumerate(outputs):
probs = out.outputs.probs
print(f"Request {i}, class probs = {probs}")
if __name__ == "__main__":
llm = LLM(model=MODELPATH, task="classify",
trust_remote_code=True,
enforce_eager=True,
enable_chunked_prefill=False)
test_prompt(llm)
test_tokenid(llm)
# print(input_ids)
\ No newline at end of file
......@@ -10,7 +10,7 @@ import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar
from vllm.inputs import token_inputs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function)
......@@ -89,7 +89,7 @@ class LLM:
or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
allowed_media_domains: If set, only media URLs that belong to this
allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
......@@ -984,6 +984,9 @@ class LLM:
# Use default pooling params.
pooling_params = PoolingParams()
if prompts["qfeat"] is not None:
pooling_params.qfeat = prompts["qfeat"]
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
......
......@@ -120,15 +120,15 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""
......@@ -220,16 +220,21 @@ def token_inputs(
prompt_token_ids: list[int],
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
qfeat: Optional[list] = None,
) -> TokenInputs:
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
# print("************* {} ***************".format(qfeat))
if isinstance(prompt_token_ids, torch.Tensor):
prompt_token_ids = prompt_token_ids.tolist()
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None:
inputs["prompt"] = prompt
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
if qfeat is not None:
inputs["qfeat"] = qfeat
return inputs
......
......@@ -314,7 +314,6 @@ class InputPreprocessor:
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
......@@ -325,7 +324,8 @@ class InputPreprocessor:
mm_uuids=mm_uuids,
)
else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids)
inputs = token_inputs(prompt_token_ids=prompt_token_ids,
qfeat=parsed_content["qfeat"])
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
......@@ -358,6 +358,7 @@ class InputPreprocessor:
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
qfeat=parsed_content["qfeat"]
)
if cache_salt := parsed_content.get("cache_salt"):
......
......@@ -347,7 +347,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
classifier=self._classifier,
act_fn=PoolerIdentity(),
)
})
def _classifier(self, x: torch.Tensor, pooling_metadata: PoolingMetadata = None):
......@@ -377,6 +377,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
# seq_length].squeeze(-1)
return reward
def forward(
self,
input_ids: torch.Tensor,
......@@ -423,7 +424,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
if is_pooling_model(cls):
return cls
# Lazy import
# Lazy importGPUModelRunner
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
......@@ -452,6 +453,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "pool_head"),
return_bias=False,
)
self.pool_head2 = ReplicatedLinear(
config.hidden_size,
config.class_num,
......@@ -461,7 +463,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "pool_head2"),
return_bias=False,
)
self.qfeat_emb =ReplicatedLinear(
self.qfeat_emb = ReplicatedLinear(
2,
128,
bias=True,
......@@ -470,14 +473,15 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "qfeat_emb"),
return_bias=False,
)
self.qfeat_emb_topic = VocabParallelEmbedding(
100,
128,
quant_config=quant_config,
prefix=f"{prefix}.qfeat_emb_topic",
)
self.qfeat_fc1 =ReplicatedLinear(
100,
128,
quant_config=quant_config,
prefix=f"{prefix}.qfeat_emb_topic",
)
self.qfeat_fc1 = ReplicatedLinear(
256,
256,
bias=True,
......@@ -486,8 +490,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "qfeat_fc1"),
return_bias=False,
)
self.qfeat_fc2 =ReplicatedLinear(
self.qfeat_fc2 = ReplicatedLinear(
256,
3,
bias=True,
......@@ -496,8 +500,8 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
prefix=maybe_prefix(prefix, "qfeat_fc2"),
return_bias=False,
)
self.qfeat_fc3 =ReplicatedLinear(
self.qfeat_fc3 = ReplicatedLinear(
256,
3,
bias=True,
......@@ -530,17 +534,20 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
classifier=self._classifier,
act_fn=PoolerIdentity(),
)
})
def encode_qfeat(self, qfeat):
emb1 = self.qfeat_emb(qfeat[:,:2])
emb2 = self.qfeat_emb_topic(qfeat[:,2].to(torch.long))
hidden = torch.cat([emb1, emb2], dim=1)
hidden = self.qfeat_fc1(hidden)
hidden = torch.relu(hidden)
# hidden = torch.softmax(hidden, dim=1)
return hidden
def _classifier(self, x: torch.Tensor,pooling_metadata: PoolingMetadata=None):
def _classifier(self, x: torch.Tensor, pooling_metadata: PoolingMetadata=None):
pooled_output= self.pool_head(x)
if isinstance(pooled_output, tuple):
pooled_output = pooled_output[0]
......@@ -549,50 +556,38 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
pooled_output_rel = self.pool_head2(pooled_output) # bs * class_num
pooled_output_time = self.pool_head2(pooled_output) # bs * class_num
pooled_output_auth = self.pool_head2(pooled_output) # bs * class_num
qfeat=torch.tensor([[2, 0, 20]],device=pooled_output.device)
# print(f"**************sat_logits:{sat_logits}********")
# qfeat = torch.tensor([[2, 0, 20]], device=pooled_output.device)
# print("*** adapters_classify.py L558 pooling_metadata.pooling_params", pooling_metadata.pooling_params[0].qfeat)
#PoolingParams(task=classify, normalize=None, dimensions=None, activation=True, softmax=None, step_tag_id=None, returned_token_ids=None, requires_token_ids=False, qfeat=None, extra_kwargs=None)
qfeat = pooling_metadata.pooling_params[0].qfeat
qfeat = torch.tensor([qfeat], device=pooled_output.device) if qfeat is not None else torch.tensor([[1, 2, 3]], device=pooled_output.device)
qfeat = qfeat.to(dtype=pooled_output.dtype)
print("*** qfeat ***", qfeat)
qhidden = self.encode_qfeat(qfeat)
a_wei = self.qfeat_fc2(qhidden)
a_bias = self.qfeat_fc3(qhidden)
if pooled_output.size()[1]<3:
if pooled_output.size()[1] < 3:
batch_size = pooled_output.size(0) # 或 pooled_output.shape[0]
reward = torch.full((batch_size, 1), float('inf'), device=pooled_output.device, dtype=pooled_output.dtype)
return reward
last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu-1
last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu - 1
batch_indices = torch.arange(last_token_idx.size(0))
sat_logits = pooled_output_sat[batch_indices, last_token_idx,:]
auth_logits = pooled_output_auth[batch_indices, last_token_idx-1,:]
time_logits = pooled_output_time[batch_indices, last_token_idx-2,:]
rel_logits = pooled_output_rel[batch_indices, last_token_idx-3,:]
sat_logits = pooled_output_sat[batch_indices, last_token_idx-1,:]
# print(f"**************sat_logits:{sat_logits}********")
auth_logits = pooled_output_auth[batch_indices, last_token_idx-2,:]
# print(f"**************auth_logits:{auth_logits}********")
time_logits = pooled_output_time[batch_indices, last_token_idx-3,:]
# print(f"**************time_logits:{time_logits}********")
rel_logits = pooled_output_rel[batch_indices, last_token_idx-4,:]
# print(f"**************rel_logits:{rel_logits}********")
multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
task_logits = torch.sigmoid(task_logits)
sat_logits_new = task_logits * sat_logits
logits = sat_logits_new
reward = logits
# sat_logits = pooled_output_sat[torch.arange(batch_size, device=pooled_output.device), seq_length-1]
# auth_logits = pooled_output_auth[torch.arange(batch_size, device=pooled_output.device), seq_length-2]
# time_logits = pooled_output_time[torch.arange(batch_size, device=pooled_output.device), seq_length-3]
# rel_logits = pooled_output_rel[torch.arange(batch_size, device=pooled_output.device), seq_length-4]
# # a_score = torch.sigmoid(torch.concat([rel_logits, time_logits, auth_logits], dim=1))
# multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
# task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
# task_logits = torch.sigmoid(task_logits)
# #gate_time = (a_wei * multii_logits + wei_time).sum(dim=1, keepdim=True)
# #gate_time = torch.sigmoid(gate_time)
# #gate_auth = (a_wei * multii_logits + wei_auth).sum(dim=1, keepdim=True)
# #gate_auth = torch.sigmoid(gate_auth)
# sat_logits_new = task_logits * sat_logits
# #logits = 2.0 * sat_logits_new.detach() + 0.25 * (qfeat[:,0].float().unsqueeze(1)) * gate_time * time_logits.detach() + 0.5 * (qfeat[:,1].float().unsqueeze(1) + 0.4) * gate_auth * auth_logits.detach()
# logits = sat_logits_new
# reward = logits.squeeze(-1)
# print(reward)
return reward
......@@ -601,12 +596,13 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None
) -> torch.Tensor:
self.input_ids =input_ids
return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)
......
......@@ -21,7 +21,6 @@ import regex as re
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
......@@ -36,9 +35,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod, PoolingType)
# from vllm.model_executor.layers.pooler import (ClassifierPooler,
# DispatchPooler, Pooler,
# PoolingMethod, PoolingType)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -52,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_layers, maybe_prefix)
import torch.nn.functional as F
def _is_moe(config: PretrainedConfig) -> bool:
num_experts = getattr(config, "num_experts", None)
......@@ -124,7 +123,7 @@ class HunYuanSparseMoeBlock(nn.Module):
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
......@@ -134,7 +133,7 @@ class HunYuanSparseMoeBlock(nn.Module):
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
# Get layer_id topk if config.moe_topk is a list
if isinstance(config.moe_topk, list):
assert layer_id >= 0
......@@ -238,7 +237,7 @@ class HunYuanAttention(nn.Module):
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
layer_id: int = -1,
layer_id: int = -1
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -327,17 +326,40 @@ class HunYuanAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
ori_k = k
ori_v = v
if self.use_qk_norm:
q = self.query_layernorm(
q.view(-1, self.num_heads, self.head_dim).contiguous())
k = self.key_layernorm(
k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
attn_output = self.attn(q, k, v)
# For o_proj
attn_output = attn_output.view(q.shape[0], -1)
# Transpose.
q = q.view(-1, self.num_heads, self.head_dim)
# Expand the key and value to handle GQA.
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
q = q.unsqueeze(0).transpose(1, 2).contiguous()
k = k.unsqueeze(0).transpose(1, 2).contiguous()
v = v.unsqueeze(0).transpose(1, 2).contiguous()
# print("*** hunyuan.py L354 q.shape[0]", q.shape[0])
attn_mask = torch.ones(q.shape[0], 1, q.shape[2], q.shape[2]).cuda().contiguous()
# print("====",q.shape, k.shape, v.shape, attn_mask.shape)
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask,
dropout_p=0.0
)
attn_output = attn_output.transpose(1, 2).flatten(-2, -1).squeeze()
output, _ = self.o_proj(attn_output)
return output, (ori_k, v)
return output, (ori_k, ori_v)
class HunYuanCrossAttention(nn.Module):
......@@ -439,6 +461,7 @@ class HunYuanCrossAttention(nn.Module):
hidden_states: torch.Tensor,
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> torch.Tensor:
attention_mask=[]
assert kv_states is not None
ori_k, v = kv_states # use last layer kv,
k = ori_k
......@@ -450,8 +473,8 @@ class HunYuanCrossAttention(nn.Module):
q.view(-1, self.num_heads, self.head_dim).contiguous())
k = self.key_layernorm(
k.view(-1, self.num_kv_heads, self.head_dim).contiguous())
attn_output = self.attn(q, k, v)
attn_output = self.attn(hidden_states, attention_mask, positions,kv_states)
# For o_proj
attn_output = attn_output.view(q.shape[0], -1)
output, _ = self.o_proj(attn_output)
......@@ -467,14 +490,14 @@ class HunYuanDecoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_id: int = -1,
enable_eplb: bool = False,
enable_eplb: bool = False
) -> None:
super().__init__()
assert layer_id >= 0
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.intermediate_size = (config.intermediate_size if isinstance(
config.intermediate_size, int) else
config.intermediate_size, int) else
config.intermediate_size[layer_id])
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
......@@ -490,7 +513,7 @@ class HunYuanDecoderLayer(nn.Module):
attention_type = (AttentionType.ENCODER_DECODER
if layer_id >= 0 and layer_id % cla_factor != 0 else
AttentionType.DECODER)
if attention_type == AttentionType.DECODER:
self.self_attn = HunYuanAttention(
config=config,
......@@ -505,7 +528,7 @@ class HunYuanDecoderLayer(nn.Module):
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
layer_id=layer_id,
layer_id=layer_id
)
elif attention_type == AttentionType.ENCODER_DECODER:
self.self_attn = HunYuanCrossAttention(
......@@ -526,7 +549,7 @@ class HunYuanDecoderLayer(nn.Module):
else:
raise RuntimeError(f"Unsupported attention type: {attention_type}")
if _is_moe(config):
self.mlp = HunYuanSparseMoeBlock(
config=config,
......@@ -556,8 +579,8 @@ class HunYuanDecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
kv_states: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
residual=hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -566,12 +589,13 @@ class HunYuanDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_states=kv_states,
)
hidden_states =residual+hidden_states
residual=hidden_states
hidden_states= self.post_attention_layernorm(hidden_states)
hidden_states=self.mlp(hidden_states)
hidden_states=hidden_states+residual
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
return hidden_states, residual, ori_kv_states
......@@ -581,7 +605,6 @@ class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
......@@ -614,16 +637,15 @@ class HunYuanModel(nn.Module):
layer_id=int(prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
prefix=prefix
),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -644,9 +666,9 @@ class HunYuanModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual, kv_states = layer(
......@@ -655,7 +677,7 @@ class HunYuanModel(nn.Module):
residual,
prev_kv_states,
)
if (getattr(self.config, "use_cla", False)
and (i - self.start_layer) % cla_factor == 0):
prev_kv_states = kv_states
......@@ -692,7 +714,7 @@ class HunYuanModel(nn.Module):
k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size)
return torch.concat((q, k, v))
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if _is_moe(self.config):
# Params for weights, fp8 weight scales, fp8 activation scales
......@@ -706,7 +728,7 @@ class HunYuanModel(nn.Module):
)
else:
return []
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
cla_factor = _get_cla_factor(self.config)
stacked_params_mapping = [
......@@ -894,11 +916,6 @@ class HunYuanModel(nn.Module):
return loaded_params
class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
......@@ -920,10 +937,10 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.quant_config = quant_config
self.pad_id = self.config.pad_id
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
......@@ -979,9 +996,6 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def forward(
self,
input_ids: torch.Tensor,
......@@ -990,13 +1004,13 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
self,
......@@ -1030,4 +1044,4 @@ class HunYuanForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -23,7 +23,7 @@ class PoolingParams(
truncate_prompt_tokens: Controls prompt truncation.
Set to -1 to use the model's default truncation size.
Set to k to keep only the last k tokens (left truncation).
Set to None to disable truncation.
Set to None to disable truncation.
normalize: Whether to normalize the embeddings outputs.
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
......@@ -64,6 +64,8 @@ class PoolingParams(
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
qfeat: Optional[list] = None
@property
def all_parameters(self) -> list[str]:
return [
......@@ -184,6 +186,7 @@ class PoolingParams(
f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, "
f"qfeat={self.qfeat}, "
f"extra_kwargs={self.extra_kwargs})")
def __post_init__(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment