Commit fe1ab618 authored by guanyu1's avatar guanyu1
Browse files

test2_强制修改为ALL

parent 3146b529
...@@ -322,7 +322,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T: ...@@ -322,7 +322,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
quant_config=quant_config, quant_config=quant_config,
params_dtype=torch.float32, params_dtype=torch.float32,
prefix=maybe_prefix(prefix, "pool_head2"), prefix=maybe_prefix(prefix, "pool_head2"),
return_bias=True, return_bias=False,
) )
# 兼容 ForSequenceClassification:将 score 直接指向最终分类头 # 兼容 ForSequenceClassification:将 score 直接指向最终分类头
# 不再单独创建一层;pool_head2 即最终打分层 # 不再单独创建一层;pool_head2 即最终打分层
...@@ -457,7 +457,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -457,7 +457,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
quant_config=quant_config, quant_config=quant_config,
params_dtype=torch.float32, params_dtype=torch.float32,
prefix=maybe_prefix(prefix, "pool_head2"), prefix=maybe_prefix(prefix, "pool_head2"),
return_bias=True, return_bias=False,
) )
self.qfeat_emb =ReplicatedLinear( self.qfeat_emb =ReplicatedLinear(
2, 2,
...@@ -514,9 +514,10 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -514,9 +514,10 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
"PoolerConfig must be provided to use classification head") "PoolerConfig must be provided to use classification head")
# Determine pooling type (fallback to config.pool_type) # Determine pooling type (fallback to config.pool_type)
pooling_type_str = (pooler_config.pooling_type # pooling_type_str = (pooler_config.pooling_type
if pooler_config.pooling_type is not None # if pooler_config.pooling_type is not None
else getattr(config, "pool_type", "LAST")).upper() # else getattr(config, "pool_type", "LAST")).upper()
pooling_type_str="ALL"
if pooling_type_str == "LASTTOKEN": if pooling_type_str == "LASTTOKEN":
pooling_type_str = "LAST" pooling_type_str = "LAST"
pooling_type = PoolingType[pooling_type_str] pooling_type = PoolingType[pooling_type_str]
...@@ -552,10 +553,10 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -552,10 +553,10 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
a_wei = self.qfeat_fc2(qhidden) a_wei = self.qfeat_fc2(qhidden)
a_bias = self.qfeat_fc3(qhidden) a_bias = self.qfeat_fc3(qhidden)
sat_logits = pooled_output_sat[:,-1] sat_logits = pooled_output_sat[:,-1,:]
auth_logits = pooled_output_auth[:,-2] auth_logits = pooled_output_auth[:,-2,:]
time_logits = pooled_output_time[:,-3] time_logits = pooled_output_time[:,-3,:]
rel_logits = pooled_output_rel[:,-4] rel_logits = pooled_output_rel[:,-4,:]
multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1) multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True) task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
task_logits = torch.sigmoid(task_logits) task_logits = torch.sigmoid(task_logits)
...@@ -586,7 +587,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -586,7 +587,7 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
return pooled_output#reward return reward
def forward( def forward(
self, self,
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try: try:
# __version__ = "0.11.0" #__version__ = "0.11.0"
# __version_tuple__ = (0, 11, 0) #__version_tuple__ = (0, 11, 0)
# __hcu_version__ = f'0.11.0+das.opt1.alpha.c16e075.dtk25042' #__hcu_version__ = f'0.11.0+das.opt1.alpha.c16e075.dtk25042'
# from vllm.version import __version__, __version_tuple__, __hcu_version__ #from vllm.version import __version__, __version_tuple__, __hcu_version__
from ._version import __version__, __version_tuple__ from ._version import __version__, __version_tuple__
except Exception as e: except Exception as e:
import warnings import warnings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment