Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d939d6fd
Commit
d939d6fd
authored
Jun 27, 2019
by
thomwolf
Browse files
fix hidden-state extraction
parent
0c2ff348
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
3 deletions
+6
-3
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+6
-3
No files found.
pytorch_pretrained_bert/modeling_xlnet.py
View file @
d939d6fd
...
...
@@ -855,18 +855,21 @@ class XLNetModel(XLNetPreTrainedModel):
for
i
,
layer_module
in
enumerate
(
self
.
layer
):
# cache new mems
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
hidden_states
.
append
((
output_h
,
output_g
))
hidden_states
.
append
((
output_h
,
output_g
)
if
output_g
is
not
None
else
output_h
)
output_h
,
output_g
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
hidden_states
.
append
((
output_h
,
output_g
))
hidden_states
.
append
((
output_h
,
output_g
)
if
output_g
is
not
None
else
output_h
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output
=
output
.
permute
(
1
,
0
,
2
).
contiguous
()
if
output_g
is
not
None
:
hidden_states
=
[
h
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
for
h
in
hs
]
else
:
hidden_states
=
[
hs
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
]
return
output
,
hidden_states
,
new_mems
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment