chenpangpang / transformers

Commit 971c2468, authored Jul 03, 2019 by LysandreJik
Parent: be54b169

XLNET can be exported to TorchScript

Showing 1 changed file with 22 additions and 21 deletions (+22, -21):
pytorch_pretrained_bert/modeling_xlnet.py
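The hunks below make the XLNet modules traceable: the Python lists that previously accumulated per-layer outputs become tuples (torch.jit.trace only accepts tensors and nested tuples of tensors as module outputs), a slice and a Python `**` are rewritten as explicit torch ops, and the tied LM-head weight is replaced by a cloned parameter. A minimal sketch of the export this enables follows; the pretrained shortcut name, input shape, and output file name are illustrative assumptions, not part of the commit.

```python
# Hedged sketch of the TorchScript export this commit enables (not part of the diff).
# 'xlnet-large-cased', the input shape, and the file name are assumptions.
import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetModel

model = XLNetModel.from_pretrained('xlnet-large-cased')
model.eval()  # disable dropout so the trace is deterministic

dummy_input = torch.randint(0, 32000, (1, 16), dtype=torch.long)  # [bsz, seq_len]
traced = torch.jit.trace(model, dummy_input)  # fails if forward() returns Python lists
traced.save("xlnet_traced.pt")
```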
@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
         x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
         x = x[1:, ...]
         x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
-        x = x[:, 0:klen, :, :]
+        # x = x[:, 0:klen, :, :]
+        x = torch.index_select(x, 1, torch.arange(klen))

         return x
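The first hunk replaces a slice with `torch.index_select` plus an explicit `torch.arange` index, presumably so the selection shows up as a regular tensor op to the tracer rather than a Python-level slice over a plain integer. A quick equivalence check (shapes and `klen` are made-up values, not from the commit):

```python
# Both lines keep the first klen positions along dim 1; illustrative values only.
import torch

x = torch.randn(4, 7, 2, 3)
klen = 5
a = x[:, 0:klen, :, :]
b = torch.index_select(x, 1, torch.arange(klen))
assert torch.equal(a, b)
```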
@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
             output_h = self.post_attention(h, attn_vec)
             output_g = None

-        outputs = [output_h, output_g]
+        outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs = outputs + [attn_prob]
+            outputs += (attn_prob,)
         return outputs


 class XLNetFeedForward(nn.Module):
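This hunk, and most of those that follow, swaps list outputs for tuples: torch.jit.trace only supports tensors and (possibly nested) tuples of tensors as module outputs, so any layer that returned a list broke tracing. A toy demonstration (not from the diff; the tracer's exact behaviour on list outputs varies by PyTorch version):

```python
# Tuple outputs trace cleanly; list outputs are rejected or only tolerated with a
# tracer warning, depending on the PyTorch version.
import torch
import torch.nn as nn

class ReturnsTuple(nn.Module):
    def forward(self, x):
        return (x * 2, x + 1)   # tuple of tensors -> traceable

class ReturnsList(nn.Module):
    def forward(self, x):
        return [x * 2, x + 1]   # list of tensors -> problematic for the tracer

x = torch.randn(2, 3)
torch.jit.trace(ReturnsTuple(), (x,))       # traces cleanly
try:
    torch.jit.trace(ReturnsList(), (x,))    # raises or warns, version dependent
    print("list output tolerated (with a tracer warning) by this PyTorch version")
except RuntimeError as err:
    print("list output rejected by the tracer:", err)
```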
@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)

-        outputs = [output_h, output_g] + outputs[2:]  # Add again attentions if there are there
+        outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs
@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
-        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
+        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

         if self.attn_type == 'bi':
             # beg, end = klen - 1, -qlen
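Here the Python `**` on an integer base becomes `torch.pow`, keeping the computation inside explicit torch ops; the resulting values are identical. A small equivalence check (the `d_model` value is illustrative):

```python
import torch

d_model = 768  # illustrative value
freq_seq = torch.arange(0, d_model, 2.0, dtype=torch.float)
a = 1 / (10000 ** (freq_seq / d_model))
b = 1 / torch.pow(10000, (freq_seq / d_model))
assert torch.allclose(a, b)
```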
@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             head_mask = [None] * self.n_layer

-        new_mems = []
+        new_mems = ()
         if mems is None:
             mems = [None] * len(self.layer)
@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems.append(self.cache_mem(output_h, mems[i]))
+            new_mems += (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)
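The per-layer memory cache is now grown as a tuple (`new_mems += (...,)`) instead of via `list.append`, so the final return value stays a nested tuple of tensors. A toy module mirroring the pattern (illustrative only, not the real XLNetModel):

```python
import torch
import torch.nn as nn

class TupleAccumulator(nn.Module):
    """Accumulates one cached tensor per layer into a tuple, as the diff does."""
    def __init__(self, n_layer=3):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(n_layer)])

    def forward(self, x):
        new_mems = ()
        for layer in self.layers:
            x = layer(x)
            new_mems += (x.detach(),)   # grows a tuple instead of appending to a list
        return (x, new_mems)            # nested tuple of tensors -> traceable

traced = torch.jit.trace(TupleAccumulator(), (torch.randn(2, 4),))
out, mems = traced(torch.randn(2, 4))
print(len(mems))  # 3 cached tensors, one per layer
```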
@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
         output = self.dropout(output_g if output_g is not None else output_h)

         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-        outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
+        outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
         if self.output_hidden_states:
             if output_g is not None:
-                hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
-                hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
-            outputs.append(hidden_states)
+                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
+            outputs += (hidden_states,)
         if self.output_attentions:
-            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs.append(attentions)
+            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            outputs += (attentions,)

         return outputs  # outputs, new_mems, (hidden_states), (attentions)
@@ -986,7 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = self.transformer.word_embedding.weight
+        self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
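The `tie_weights` change above replaces true weight sharing with an independent clone wrapped in `nn.Parameter`, presumably because the tracing/export path could not handle two modules pointing at the same parameter storage; the trade-off is that the LM head and the embedding no longer stay in sync during training. A toy illustration of the storage difference (not the real modules):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)

tied = emb.weight                            # shares storage with the embedding
copied = nn.Parameter(emb.weight.clone())    # independent copy, as in the diff

print(tied.data_ptr() == emb.weight.data_ptr())    # True  (tied)
print(copied.data_ptr() == emb.weight.data_ptr())  # False (decoupled)
```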
@@ -1026,14 +1027,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         logits = self.lm_loss(transformer_outputs[0])

-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)),
                             labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
@@ -1061,7 +1062,7 @@ class XLNetSequenceSummary(nn.Module):
             output = hidden_states[:, 0]
         elif self.summary_type == 'mean':
             output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
+        elif self.summary_type == 'attn':
             raise NotImplementedError

         output = self.summary(output)
@@ -1180,7 +1181,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)

-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             if self.num_labels == 1:
@@ -1190,7 +1191,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
@@ -1271,7 +1272,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        outputs = [start_logits, end_logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
@@ -1288,6 +1289,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
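Continuing the export sketch at the top of the page, the saved module can be reloaded and called like a regular function; its outputs come back with the tuple structure introduced by this commit (the file name and input shape remain illustrative assumptions):

```python
import torch

loaded = torch.jit.load("xlnet_traced.pt")
outputs = loaded(torch.randint(0, 32000, (1, 16), dtype=torch.long))
print(type(outputs))  # tuple: hidden states, new_mems, and optionally more
```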