Commit fe1ab618 authored by guanyu1's avatar guanyu1
Browse files

test2_强制修改为ALL

parent 3146b529
......@@ -322,7 +322,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
quant_config=quant_config,
params_dtype=torch.float32,
prefix=maybe_prefix(prefix, "pool_head2"),
return_bias=True,
return_bias=False,
)
# 兼容 ForSequenceClassification:将 score 直接指向最终分类头
# 不再单独创建一层;pool_head2 即最终打分层
......@@ -457,7 +457,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
quant_config=quant_config,
params_dtype=torch.float32,
prefix=maybe_prefix(prefix, "pool_head2"),
return_bias=True,
return_bias=False,
)
self.qfeat_emb =ReplicatedLinear(
2,
......@@ -514,9 +514,10 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
"PoolerConfig must be provided to use classification head")
# Determine pooling type (fallback to config.pool_type)
pooling_type_str = (pooler_config.pooling_type
if pooler_config.pooling_type is not None
else getattr(config, "pool_type", "LAST")).upper()
# pooling_type_str = (pooler_config.pooling_type
# if pooler_config.pooling_type is not None
# else getattr(config, "pool_type", "LAST")).upper()
pooling_type_str="ALL"
if pooling_type_str == "LASTTOKEN":
pooling_type_str = "LAST"
pooling_type = PoolingType[pooling_type_str]
......@@ -552,10 +553,10 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
a_wei = self.qfeat_fc2(qhidden)
a_bias = self.qfeat_fc3(qhidden)
sat_logits = pooled_output_sat[:,-1]
auth_logits = pooled_output_auth[:,-2]
time_logits = pooled_output_time[:,-3]
rel_logits = pooled_output_rel[:,-4]
sat_logits = pooled_output_sat[:,-1,:]
auth_logits = pooled_output_auth[:,-2,:]
time_logits = pooled_output_time[:,-3,:]
rel_logits = pooled_output_rel[:,-4,:]
multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
task_logits = torch.sigmoid(task_logits)
......@@ -586,7 +587,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
return pooled_output#reward
return reward
def forward(
self,
......
......@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
# __version__ = "0.11.0"
# __version_tuple__ = (0, 11, 0)
# __hcu_version__ = f'0.11.0+das.opt1.alpha.c16e075.dtk25042'
# from vllm.version import __version__, __version_tuple__, __hcu_version__
#__version__ = "0.11.0"
#__version_tuple__ = (0, 11, 0)
#__hcu_version__ = f'0.11.0+das.opt1.alpha.c16e075.dtk25042'
#from vllm.version import __version__, __version_tuple__, __hcu_version__
from ._version import __version__, __version_tuple__
except Exception as e:
import warnings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment