Commit 0c2ff348 authored by thomwolf

extracting double hidden-state from xlnet

parent 3deea56c
@@ -703,8 +703,7 @@ class XLNetModel(XLNetPreTrainedModel):
         return pos_emb

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                output_all_encoded_layers=True, head_mask=None):
+                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -856,13 +855,14 @@ class XLNetModel(XLNetPreTrainedModel):
         for i, layer_module in enumerate(self.layer):
             # cache new mems
             new_mems.append(self.cache_mem(output_h, mems[i]))
+            hidden_states.append((output_h, output_g))

             output_h, output_g = layer_module(output_h, output_g,
                                               attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                               r=pos_emb, seg_mat=seg_mat,
                                               mems=mems[i], target_mapping=target_mapping,
                                               head_mask=head_mask)
-            hidden_states.append(output_h)
+            hidden_states.append((output_h, output_g))

         output = self.dropout(output_g if output_g is not None else output_h)

         # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
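Note on the hunk above: each entry of hidden_states now carries both streams of XLNet's two-stream attention, i.e. an (output_h, output_g) pair per layer instead of output_h alone; output_g is only populated when inp_q/target_mapping are supplied, which is why the dropout line falls back to output_h otherwise. Below is a minimal consumption sketch, not part of this commit, assuming an already-built XLNetModel instance named model and a LongTensor of token ids named inp_k (both hypothetical names):

# Sketch only: unpack the paired hidden states returned by XLNetModel.forward().
output, hidden_states, new_mems = model(inp_k)
for i, (output_h, output_g) in enumerate(hidden_states):
    # output_h: content-stream hidden state for layer i
    # output_g: query-stream hidden state, or None when inp_q/target_mapping are not given
    g_shape = None if output_g is None else tuple(output_g.shape)
    print("layer", i, tuple(output_h.shape), g_shape)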
@@ -955,7 +955,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -987,8 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
                 to pool the input to get a vector representation.
         """
         output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                                            mems, perm_mask, target_mapping, inp_q,
-                                                            output_all_encoded_layers, head_mask)
+                                                            mems, perm_mask, target_mapping, inp_q, head_mask)

         logits = self.lm_loss(output)
@@ -1001,10 +1000,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         # if self.output_attentions:
         #     all_attentions, encoded_layers = encoded_layers
-        # sequence_output = encoded_layers[-1]
-        # pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        #     encoded_layers = encoded_layers[-1]

         # if self.output_attentions:
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
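With the two LM-head hunks above, output_all_encoded_layers is gone from XLNetLMHeadModel.forward, and a call without labels returns (logits, new_mems), where logits come from the lm_loss projection over the vocabulary and new_mems can be fed back for the next segment. A hedged usage sketch, assuming an XLNetLMHeadModel instance lm_model and token-id tensors inp_k / next_inp_k (hypothetical names, not part of this commit):

# Sketch only: score one segment, then reuse the returned memories for the next one.
logits, new_mems = lm_model(inp_k)                            # per-position vocabulary logits
logits_next, new_mems = lm_model(next_inp_k, mems=new_mems)   # Transformer-XL style recurrence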
@@ -1127,7 +1122,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1156,8 +1151,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                 Set to None during finetuning.
         """
         output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                                mems, perm_mask, target_mapping, inp_q,
-                                                output_all_encoded_layers, head_mask)
+                                                mems, perm_mask, target_mapping, inp_q, head_mask)

         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
@@ -1174,10 +1168,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         # if self.output_attentions:
         #     all_attentions, encoded_layers = encoded_layers
-        # sequence_output = encoded_layers[-1]
-        # pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        #     encoded_layers = encoded_layers[-1]

         # if self.output_attentions:
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
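The classification head shown above first pools the sequence output with self.sequence_summary and then projects the pooled vector to class logits with self.logits_proj. The following self-contained stand-in illustrates that two-step flow with made-up shapes; last-token pooling is an assumption here, and nothing below is the model's actual module:

import torch

bsz, seq_len, d_model, num_labels = 2, 7, 16, 3
output = torch.randn(bsz, seq_len, d_model)      # stand-in for the transformer output
summary = output[:, -1]                          # assumed last-token pooling -> [bsz, d_model]
logits_proj = torch.nn.Linear(d_model, num_labels)
logits = logits_proj(summary)                    # [bsz, num_labels] class scores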
@@ -1248,11 +1238,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None,
-                output_all_encoded_layers=True, head_mask=None):
+                start_positions=None, end_positions=None, head_mask=None):

         output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                                mems, perm_mask, target_mapping, inp_q,
-                                                output_all_encoded_layers, head_mask)
+                                                mems, perm_mask, target_mapping, inp_q, head_mask)

         logits = self.qa_outputs(output)
         start_logits, end_logits = logits.split(1, dim=-1)
...
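In the question-answering hunk, qa_outputs projects every position to two scores, which are then split into start and end logits. A runnable toy illustration of that split follows; the shapes and the Linear layer are stand-ins for illustration, not the model's actual configuration:

import torch

bsz, seq_len, d_model = 2, 7, 16
output = torch.randn(bsz, seq_len, d_model)         # stand-in for the transformer output
qa_outputs = torch.nn.Linear(d_model, 2)            # mirrors the 2-way projection split above
logits = qa_outputs(output)                         # [bsz, seq_len, 2]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)             # [bsz, seq_len] span-start scores
end_logits = end_logits.squeeze(-1)                 # [bsz, seq_len] span-end scores
start_idx, end_idx = start_logits.argmax(-1), end_logits.argmax(-1)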