Commit e99b2014 authored by thomwolf

fixes #471

parent 19666dcb
@@ -371,8 +371,8 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
     def forward(self, hidden_states, mc_token_ids):
         # Classification logits
         # hidden_state (bsz, num_choices, seq_length, hidden_size)
-        # mc_token_ids (bsz, num_choices)
-        mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
+        # mc_token_ids (bsz, num_choices, 1)
+        mc_token_ids = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
         # (bsz, num_choices, 1, hidden_size)
         multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
         # (bsz, num_choices, hidden_size)
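To make the shape bookkeeping concrete, here is a minimal, self-contained sketch of what the updated gather computes. Shapes and variable names follow the diff; the random tensors, the concrete sizes, and the checking loop are illustrative stand-ins, not part of the commit.

```python
import torch

bsz, num_choices, seq_length, hidden_size = 2, 4, 7, 8
hidden_states = torch.randn(bsz, num_choices, seq_length, hidden_size)

# mc_token_ids now carries a trailing singleton dim: (bsz, num_choices, 1)
mc_token_ids = torch.randint(seq_length, (bsz, num_choices, 1))

# (bsz, num_choices, 1) -> (bsz, num_choices, 1, 1) -> (bsz, num_choices, 1, hidden_size)
index = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_size)

# For every example and choice, pick the hidden state at the classification
# token's position along the sequence dimension (dim 2)
multiple_choice_h = hidden_states.gather(2, index).squeeze(2)
assert multiple_choice_h.shape == (bsz, num_choices, hidden_size)

# Equivalent (slower) explicit indexing, to show what gather is doing
for b in range(bsz):
    for c in range(num_choices):
        pos = mc_token_ids[b, c, 0]
        assert torch.equal(multiple_choice_h[b, c], hidden_states[b, c, pos])
```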
@@ -605,14 +605,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
             return
         # Update config
         self.config.n_special = num_special_tokens
-        # Build new embeddings and initialize
+        # Build new embeddings and initialize all new embeddings (in particular the special tokens)
         old_embed = self.tokens_embed
         self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
-        # Initialize all new embeddings (in particular the special tokens)
         self.init_weights(self.tokens_embed)
-        # Copy word and positional embeddings from the previous weights
-        self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
-        self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
+        # Copy word embeddings from the previous weights
+        self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         if position_ids is None:
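For context on the second hunk, the sketch below replays the fixed `set_num_special_tokens` logic in isolation: grow the token embedding table, re-initialize it, then copy back only the word rows, leaving the `n_special` new rows with their fresh initialization. The dropped line had also been copying the last `n_positions` rows, which looks like a leftover from a combined token/position table and is presumably unnecessary here because positions are embedded separately. The values of `vocab_size`, `n_embd`, and `n_special`, and the normal init, are made-up stand-ins for the model's config and `init_weights`.

```python
import torch
import torch.nn as nn

vocab_size, n_embd, n_special = 100, 16, 3  # stand-ins for the model config
total_tokens_embeddings = vocab_size + n_special

old_embed = nn.Embedding(vocab_size, n_embd)

# Build the bigger table and initialize all rows (in particular the special tokens)
tokens_embed = nn.Embedding(total_tokens_embeddings, n_embd)
nn.init.normal_(tokens_embed.weight, std=0.02)  # stand-in for self.init_weights

# Copy word embeddings from the previous weights; the new special-token rows
# keep their fresh initialization
tokens_embed.weight.data[:vocab_size, :] = old_embed.weight.data[:vocab_size, :]

assert torch.equal(tokens_embed.weight.data[:vocab_size], old_embed.weight.data)
```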