Commit 9e61b53d authored by guanyu1's avatar guanyu1
Browse files

test2除了qfeat全部改完

parent fe1ab618
......@@ -6,7 +6,7 @@ from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union
from torch.nn.utils.rnn import pad_sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -641,7 +641,8 @@ class ClassifierPooler(Pooler):
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
pooled_data = pad_sequence(pooled_data, batch_first=True, padding_value=0.0)
#pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
pooled_data = pooled_data.to(self.head_dtype)
......
......@@ -356,10 +356,12 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
if isinstance(pooled_output, tuple):
pooled_output = pooled_output[0]
pooled_output = torch.tanh(pooled_output)
pooled_output = self.pool_head2(pooled_output)
reward = self.pool_head2(pooled_output)
if isinstance(pooled_output, tuple):
pooled_output = pooled_output[0]
reward = reward[0]
# last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu-1
# batch_indices = torch.arange(last_token_idx.size(0))
# reward=pooled_output[batch_indices, last_token_idx,:]
# Select logits at the last non-pad token position per sequence
# seq_length: [batch]
# cursor = pooling_metadata.pooling_cursor
......@@ -374,7 +376,7 @@ def new_hy_05b_dense_official_classification(cls: _T) -> _T:
# reward = pooled_output[torch.arange(batch_size, device=pooled_output.device),
# seq_length].squeeze(-1)
return pooled_output
return reward
def forward(
self,
input_ids: torch.Tensor,
......@@ -552,11 +554,16 @@ def hy_2b_dense_classification_official_hf_multihead_full_mask(cls: _T) -> _T:
qhidden = self.encode_qfeat(qfeat)
a_wei = self.qfeat_fc2(qhidden)
a_bias = self.qfeat_fc3(qhidden)
sat_logits = pooled_output_sat[:,-1,:]
auth_logits = pooled_output_auth[:,-2,:]
time_logits = pooled_output_time[:,-3,:]
rel_logits = pooled_output_rel[:,-4,:]
if pooled_output.size()[1]<3:
batch_size = pooled_output.size(0) # 或 pooled_output.shape[0]
reward = torch.full((batch_size, 1), float('inf'), device=pooled_output.device, dtype=pooled_output.dtype)
return reward
last_token_idx = pooling_metadata.pooling_cursor.num_scheduled_tokens_cpu-1
batch_indices = torch.arange(last_token_idx.size(0))
sat_logits = pooled_output_sat[batch_indices, last_token_idx,:]
auth_logits = pooled_output_auth[batch_indices, last_token_idx-1,:]
time_logits = pooled_output_time[batch_indices, last_token_idx-2,:]
rel_logits = pooled_output_rel[batch_indices, last_token_idx-3,:]
multii_logits = torch.concat([rel_logits, time_logits, auth_logits], dim=1)
task_logits = (a_wei * multii_logits + a_bias).sum(dim=1, keepdim=True)
task_logits = torch.sigmoid(task_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment