"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "a21cedf4ff1facaee601a635e3c092fe02742290"
Unverified commit 01aae1cc, authored by Maybewuss and committed by GitHub
Browse files

[Model] Remove redundant softmax when using PoolingType.STEP (#10415)

parent c7dec926
...@@ -118,14 +118,13 @@ class Pooler(nn.Module): ...@@ -118,14 +118,13 @@ class Pooler(nn.Module):
if returned_token_ids is not None and len(returned_token_ids) > 0: if returned_token_ids is not None and len(returned_token_ids) > 0:
hidden_states = hidden_states[:, returned_token_ids] hidden_states = hidden_states[:, returned_token_ids]
logits = hidden_states.softmax(dim=-1)
step_tag_id = self.step_tag_id step_tag_id = self.step_tag_id
offset = 0 offset = 0
pooled_data_lst = [] pooled_data_lst = []
for prompt_len, seq_data_i in zip( for prompt_len, seq_data_i in zip(
prompt_lens, pooling_metadata.seq_data.values()): prompt_lens, pooling_metadata.seq_data.values()):
pooled_data_i = logits[offset:offset + prompt_len] pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None: if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids) token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id] pooled_data_i = pooled_data_i[token_ids == step_tag_id]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment