Commit 0c2ff348 authored by thomwolf

extracting double hidden-state from xlnet

parent 3deea56c
@@ -703,8 +703,7 @@ class XLNetModel(XLNetPreTrainedModel):
         return pos_emb
 
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                output_all_encoded_layers=True, head_mask=None):
+                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
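With output_all_encoded_layers gone from the signature, a call only needs the tensors documented in the docstring. A minimal call sketch, assuming an already-constructed XLNetModel instance named model (construction elided) and the (output, hidden_states, new_mems) return that the head models in this diff unpack:

import torch

# `model` is assumed to be an XLNetModel built elsewhere (hypothetical setup).
inp_k = torch.tensor([[35, 18, 4, 92]])            # [bsz, len] token IDs
token_type_ids = torch.zeros_like(inp_k)           # single-segment input
output, hidden_states, new_mems = model(inp_k, token_type_ids=token_type_ids)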
@@ -856,13 +855,14 @@ class XLNetModel(XLNetPreTrainedModel):
         for i, layer_module in enumerate(self.layer):
             # cache new mems
             new_mems.append(self.cache_mem(output_h, mems[i]))
+            hidden_states.append((output_h, output_g))
 
             output_h, output_g = layer_module(output_h, output_g,
                                               attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                               r=pos_emb, seg_mat=seg_mat,
                                               mems=mems[i], target_mapping=target_mapping,
                                               head_mask=head_mask)
-            hidden_states.append(output_h)
 
+        hidden_states.append((output_h, output_g))
         output = self.dropout(output_g if output_g is not None else output_h)
         # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
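Continuing the call sketch above: after this hunk, every element of hidden_states is an (output_h, output_g) pair rather than output_h alone, and output_g is None whenever no prediction targets (inp_q / target_mapping) were supplied. An illustrative way to consume the pairs, not part of the diff:

content_stream = [h for h, g in hidden_states]     # per-layer content-stream activations
query_stream = [g for h, g in hidden_states]       # per-layer query-stream activations (may be None)
last_h, last_g = hidden_states[-1]                 # states recorded around the last layer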
@@ -955,7 +955,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
 
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -987,8 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
                 to pool the input to get a vector representation.
         """
         output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                                            mems, perm_mask, target_mapping, inp_q,
-                                                            output_all_encoded_layers, head_mask)
+                                                            mems, perm_mask, target_mapping, inp_q, head_mask)
 
         logits = self.lm_loss(output)
 
@@ -1001,10 +1000,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
 
         # if self.output_attentions:
         # all_attentions, encoded_layers = encoded_layers
-        # sequence_output = encoded_layers[-1]
-        # pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        # encoded_layers = encoded_layers[-1]
         # if self.output_attentions:
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
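With the flag removed and the commented-out pooling path dropped, the LM head keeps the plain (logits, new_mems) return shown above. A hedged usage sketch, assuming labels is left as None and an XLNetLMHeadModel instance named lm_model (construction elided); logits are the vocabulary scores produced by self.lm_loss, and new_mems can be recycled as memory for the next segment:

import torch

segment_a = torch.tensor([[17, 523, 8, 42]])       # [bsz, len] token IDs (hypothetical)
segment_b = torch.tensor([[99, 7, 310, 4]])

logits, mems = lm_model(segment_a)                 # first segment: no memory yet
logits, mems = lm_model(segment_b, mems=mems)      # reuse cached states as memory
next_token = logits[:, -1].argmax(dim=-1)          # greedy pick over the vocab logits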
@@ -1127,7 +1122,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
 
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1156,8 +1151,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                 Set to None during finetuning.
         """
         output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q,
-                                               output_all_encoded_layers, head_mask)
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
 
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
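The classification head above pools the sequence with self.sequence_summary and projects it with self.logits_proj. A stand-in sketch of that two-step pattern with simplified placeholders (last-token pooling, plain Linear), not the repository's actual SequenceSummary module:

import torch
import torch.nn as nn

bsz, seq_len, hidden, num_labels = 2, 8, 16, 3
output = torch.randn(bsz, seq_len, hidden)         # [bsz, len, hidden] transformer output
summary = output[:, -1]                            # crude "last token" pooling stand-in
logits_proj = nn.Linear(hidden, num_labels)        # stand-in for self.logits_proj
logits = logits_proj(summary)                      # [bsz, num_labels]
pred = logits.argmax(dim=-1)                       # predicted class per example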
@@ -1174,10 +1168,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
 
         # if self.output_attentions:
         # all_attentions, encoded_layers = encoded_layers
-        # sequence_output = encoded_layers[-1]
-        # pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        # encoded_layers = encoded_layers[-1]
         # if self.output_attentions:
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
@@ -1248,11 +1238,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
 
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None,
-                output_all_encoded_layers=True, head_mask=None):
+                start_positions=None, end_positions=None, head_mask=None):
         output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q,
-                                               output_all_encoded_layers, head_mask)
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
 
         logits = self.qa_outputs(output)
         start_logits, end_logits = logits.split(1, dim=-1)
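The QA head projects each token's hidden state to two scores and splits them into start/end logits. A self-contained sketch of that pattern with a stand-in Linear layer (not the class's actual attribute), including the usual squeeze down to [bsz, len]:

import torch
import torch.nn as nn

bsz, seq_len, hidden = 2, 8, 16
qa_outputs = nn.Linear(hidden, 2)                  # stand-in for self.qa_outputs
output = torch.randn(bsz, seq_len, hidden)         # [bsz, len, hidden] transformer output
logits = qa_outputs(output)                        # [bsz, len, 2]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)            # [bsz, len]
end_logits = end_logits.squeeze(-1)                # [bsz, len]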