Commit fe2756ff authored by thomwolf's avatar thomwolf
Browse files

update double head model

parent b509bf76
...@@ -371,7 +371,7 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): ...@@ -371,7 +371,7 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
def forward(self, hidden_states, mc_token_ids): def forward(self, hidden_states, mc_token_ids):
# Classification logits # Classification logits
# hidden_state (bsz, num_choices, seq_length, hidden_size) # hidden_state (bsz, num_choices, seq_length, hidden_size)
# mc_token_ids (bsz, num_choices, 1) # mc_token_ids (bsz, num_choices)
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# (bsz, num_choices, 1, hidden_size) # (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment