chenpangpang / transformers · Commits

Commit 0c2ff348, authored Jun 27, 2019 by thomwolf
Parent: 3deea56c

    extracting double hidden-state from xlnet

Showing 1 changed file with 9 additions and 21 deletions:

pytorch_pretrained_bert/modeling_xlnet.py (+9, -21)
@@ -703,8 +703,7 @@ class XLNetModel(XLNetPreTrainedModel):
         return pos_emb

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                output_all_encoded_layers=True, head_mask=None):
+                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -856,13 +855,14 @@ class XLNetModel(XLNetPreTrainedModel):
         for i, layer_module in enumerate(self.layer):
             # cache new mems
             new_mems.append(self.cache_mem(output_h, mems[i]))
+            hidden_states.append((output_h, output_g))
             output_h, output_g = layer_module(output_h, output_g,
                                               attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                               r=pos_emb, seg_mat=seg_mat,
                                               mems=mems[i], target_mapping=target_mapping,
                                               head_mask=head_mask)
-            hidden_states.append(output_h)
+        hidden_states.append((output_h, output_g))
         output = self.dropout(output_g if output_g is not None else output_h)

         # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
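With this change, hidden_states collects one (output_h, output_g) pair per layer, captured before each layer runs, plus a final pair for the last layer's output, instead of output_h alone. A minimal sketch of a downstream consumer of the doubled states; the helper below is illustrative only, not part of this commit, and output_g entries are None unless the query stream is active (inp_q/target_mapping given), so it falls back to output_h, mirroring the dropout line above:

    import torch

    def split_streams(hidden_states):
        # hidden_states: list of (output_h, output_g) tuples, one per layer
        # plus one for the final output, as built in the loop above.
        h_stream = [h for h, g in hidden_states]                          # content stream
        g_stream = [g if g is not None else h for h, g in hidden_states]  # query stream
        return torch.stack(h_stream), torch.stack(g_stream)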
@@ -955,7 +955,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -987,8 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             to pool the input to get a vector representation.
         """
         output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                                           mems, perm_mask, target_mapping, inp_q,
-                                                           output_all_encoded_layers, head_mask)
+                                                           mems, perm_mask, target_mapping, inp_q, head_mask)
         logits = self.lm_loss(output)
@@ -1001,10 +1000,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         # if self.output_attentions:
         #     all_attentions, encoded_layers = encoded_layers
-        # sequence_output = encoded_layers[-1]
-        # pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        #     encoded_layers = encoded_layers[-1]
         # if self.output_attentions:
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
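After the signature change, call sites simply drop the output_all_encoded_layers argument. A hedged sketch of using XLNetLMHeadModel post-commit; the pretrained shortcut name is an assumption based on the library's conventions at the time, and the sizes are made up:

    import torch
    from pytorch_pretrained_bert.modeling_xlnet import XLNetLMHeadModel

    model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')  # shortcut name assumed
    model.eval()

    inp_k = torch.randint(0, 32000, (2, 8), dtype=torch.long)  # [bsz, len] token IDs
    with torch.no_grad():
        logits, new_mems = model(inp_k)                 # no output_all_encoded_layers kwarg
        logits, new_mems = model(inp_k, mems=new_mems)  # feed memories back, Transformer-XL style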
@@ -1127,7 +1122,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1156,8 +1151,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             Set to None during finetuning.
         """
         output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q,
-                                               output_all_encoded_layers, head_mask)
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
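For the classification head, the transformer output is pooled to one vector per example by sequence_summary and projected to the label space by logits_proj. A shape walk-through with made-up sizes; the last-token pooling below is only a stand-in for what the summary module does for XLNet by default:

    import torch

    bsz, seq_len, d_model, num_labels = 2, 8, 64, 3
    output = torch.randn(bsz, seq_len, d_model)          # transformer output, [bsz, len, d_model]

    pooled = output[:, -1]                               # stand-in for sequence_summary: [bsz, d_model]
    logits_proj = torch.nn.Linear(d_model, num_labels)   # hypothetical sizes
    logits = logits_proj(pooled)                         # [bsz, num_labels]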
@@ -1174,10 +1168,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         # if self.output_attentions:
         #     all_attentions, encoded_layers = encoded_layers
-        # sequence_output = encoded_layers[-1]
-        # pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        #     encoded_layers = encoded_layers[-1]
         # if self.output_attentions:
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
@@ -1248,11 +1238,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None,
-                output_all_encoded_layers=True, head_mask=None):
+                start_positions=None, end_positions=None, head_mask=None):
         output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q,
-                                               output_all_encoded_layers, head_mask)
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
         logits = self.qa_outputs(output)
         start_logits, end_logits = logits.split(1, dim=-1)
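The span head turns each token's hidden state into a start score and an end score; split(1, dim=-1) then separates the two. A shape walk-through with made-up sizes, assuming the usual Linear(d_model, 2) for qa_outputs (the layer's definition is not shown in this diff):

    import torch

    bsz, seq_len, d_model = 2, 8, 64
    output = torch.randn(bsz, seq_len, d_model)          # transformer output
    qa_outputs = torch.nn.Linear(d_model, 2)             # assumed head: 2 scores per token

    logits = qa_outputs(output)                          # [bsz, len, 2]
    start_logits, end_logits = logits.split(1, dim=-1)   # two [bsz, len, 1] tensors
    start_logits = start_logits.squeeze(-1)              # [bsz, len]
    end_logits = end_logits.squeeze(-1)                  # [bsz, len]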