Commit 9e61b53d authored by guanyu1's avatar guanyu1
Browse files

test2除了qfeat全部改完

parent fe1ab618
...@@ -6,7 +6,7 @@ from dataclasses import dataclass ...@@ -6,7 +6,7 @@ from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from itertools import groupby from itertools import groupby
from typing import Callable, Optional, TypeVar, Union from typing import Callable, Optional, TypeVar, Union
from torch.nn.utils.rnn import pad_sequence
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -641,7 +641,8 @@ class ClassifierPooler(Pooler): ...@@ -641,7 +641,8 @@ class ClassifierPooler(Pooler):
) -> PoolerOutput: ) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list): if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data) pooled_data = pad_sequence(pooled_data, batch_first=True, padding_value=0.0)
#pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size] # pooled_data shape: [batchsize, hidden_size]
pooled_data = pooled_data.to(self.head_dtype) pooled_data = pooled_data.to(self.head_dtype)
......
...@@ -356,10 +356,12 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T: ...@@ -356,10 +356,12 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
if isinstance(pooled_output, tuple): if isinstance(pooled_output, tuple):
pooled_output = pooled_output[0] pooled_output = pooled_output[0]
pooled_output = torch.tanh(pooled_output) pooled_output = torch.tanh(pooled_output)
pooled_output = self.pool_head2(pooled_output) reward = self.pool_head2(pooled_output)
if isinstance(pooled_output, tuple): if isinstance(pooled_output, tuple):
pooled_output = pooled_output[0] reward = reward[0]
# last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu-1
# batch_indices = torch.arange(last_token_idx.size(0))
# reward=pooled_output[batch_indices, last_token_idx,:]
# Select logits at the last non-pad token position per sequence # Select logits at the last non-pad token position per sequence
# seq_length: [batch] # seq_length: [batch]
# cursor = pooling_metadata.pooling_cursor # cursor = pooling_metadata.pooling_cursor
...@@ -374,7 +376,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T: ...@@ -374,7 +376,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
# reward = pooled_output[torch.arange(batch_size, device=pooled_output.device), # reward = pooled_output[torch.arange(batch_size, device=pooled_output.device),
# seq_length].squeeze(-1) # seq_length].squeeze(-1)
return pooled_output return reward
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -552,11 +554,16 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T: ...@@ -552,11 +554,16 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
qhidden = self.encode_qfeat(qfeat) qhidden = self.encode_qfeat(qfeat)
a_wei = self.qfeat_fc2(qhidden) a_wei = self.qfeat_fc2(qhidden)
a_bias = self.qfeat_fc3(qhidden) a_bias = self.qfeat_fc3(qhidden)
if pooled_output.size()[1]<3:
sat_logits = pooled_output_sat[:,-1,:] batch_size = pooled_output.size(0) # 或 pooled_output.shape[0]
auth_logits = pooled_output_auth[:,-2,:] reward = torch.full((batch_size, 1), float('inf'), device=pooled_output.device, dtype=pooled_output.dtype)
time_logits = pooled_output_time[:,-3,:] return reward
rel_logits = pooled_output_rel[:,-4,:] last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu-1
batch_indices = torch.arange(last_token_idx.size(0))
sat_logits = pooled_output_sat[batch_indices, last_token_idx,:]
auth_logits = pooled_output_auth[batch_indices, last_token_idx-1,:]
time_logits = pooled_output_time[batch_indices, last_token_idx-2,:]
rel_logits = pooled_output_rel[batch_indices, last_token_idx-3,:]
multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1) multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True) task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
task_logits = torch.sigmoid(task_logits) task_logits = torch.sigmoid(task_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment