Commit d939d6fd authored by thomwolf's avatar thomwolf
Browse files

fix hidden-state extraction

parent 0c2ff348
@@ -855,18 +855,21 @@ class XLNetModel(XLNetPreTrainedModel):
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
new_mems.append(self.cache_mem(output_h, mems[i])) new_mems.append(self.cache_mem(output_h, mems[i]))
hidden_states.append((output_h, output_g)) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output_h, output_g = layer_module(output_h, output_g, output_h, output_g = layer_module(output_h, output_g,
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat, r=pos_emb, seg_mat=seg_mat,
mems=mems[i], target_mapping=target_mapping, mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask) head_mask=head_mask)
hidden_states.append((output_h, output_g)) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h) output = self.dropout(output_g if output_g is not None else output_h)
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = output.permute(1, 0, 2).contiguous() output = output.permute(1, 0, 2).contiguous()
if output_g is not None:
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
else:
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states] hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
return output, hidden_states, new_mems return output, hidden_states, new_mems
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment