Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
971c2468
Commit
971c2468
authored
Jul 03, 2019
by
LysandreJik
Browse files
XLNET can be exported to TorchScript
parent
be54b169
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
21 deletions
+22
-21
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+22
-21
No files found.
pytorch_pretrained_bert/modeling_xlnet.py
View file @
971c2468
...
@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
x
=
x
.
reshape
(
x_size
[
1
],
x_size
[
0
],
x_size
[
2
],
x_size
[
3
])
x
=
x
.
reshape
(
x_size
[
1
],
x_size
[
0
],
x_size
[
2
],
x_size
[
3
])
x
=
x
[
1
:,
...]
x
=
x
[
1
:,
...]
x
=
x
.
reshape
(
x_size
[
0
],
x_size
[
1
]
-
1
,
x_size
[
2
],
x_size
[
3
])
x
=
x
.
reshape
(
x_size
[
0
],
x_size
[
1
]
-
1
,
x_size
[
2
],
x_size
[
3
])
x
=
x
[:,
0
:
klen
,
:,
:]
# x = x[:, 0:klen, :, :]
x
=
torch
.
index_select
(
x
,
1
,
torch
.
arange
(
klen
))
return
x
return
x
...
@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_g
=
None
output_g
=
None
outputs
=
[
output_h
,
output_g
]
outputs
=
(
output_h
,
output_g
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
outputs
=
outputs
+
[
attn_prob
]
outputs
+=
(
attn_prob
,)
return
outputs
return
outputs
class
XLNetFeedForward
(
nn
.
Module
):
class
XLNetFeedForward
(
nn
.
Module
):
...
@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
...
@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
output_g
=
self
.
ff
(
output_g
)
output_g
=
self
.
ff
(
output_g
)
output_h
=
self
.
ff
(
output_h
)
output_h
=
self
.
ff
(
output_h
)
outputs
=
[
output_h
,
output_g
]
+
outputs
[
2
:]
# Add again attentions if there are there
outputs
=
(
output_h
,
output_g
)
+
outputs
[
2
:]
# Add again attentions if there are there
return
outputs
return
outputs
...
@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
def
relative_positional_encoding
(
self
,
qlen
,
klen
,
bsz
=
None
):
def
relative_positional_encoding
(
self
,
qlen
,
klen
,
bsz
=
None
):
"""create relative positional encoding."""
"""create relative positional encoding."""
freq_seq
=
torch
.
arange
(
0
,
self
.
d_model
,
2.0
,
dtype
=
torch
.
float
)
freq_seq
=
torch
.
arange
(
0
,
self
.
d_model
,
2.0
,
dtype
=
torch
.
float
)
inv_freq
=
1
/
(
10000
**
(
freq_seq
/
self
.
d_model
))
inv_freq
=
1
/
torch
.
pow
(
10000
,
(
freq_seq
/
self
.
d_model
))
if
self
.
attn_type
==
'bi'
:
if
self
.
attn_type
==
'bi'
:
# beg, end = klen - 1, -qlen
# beg, end = klen - 1, -qlen
...
@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
else
:
else
:
head_mask
=
[
None
]
*
self
.
n_layer
head_mask
=
[
None
]
*
self
.
n_layer
new_mems
=
[]
new_mems
=
()
if
mems
is
None
:
if
mems
is
None
:
mems
=
[
None
]
*
len
(
self
.
layer
)
mems
=
[
None
]
*
len
(
self
.
layer
)
...
@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states
=
[]
hidden_states
=
[]
for
i
,
layer_module
in
enumerate
(
self
.
layer
):
for
i
,
layer_module
in
enumerate
(
self
.
layer
):
# cache new mems
# cache new mems
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
new_mems
+=
(
self
.
cache_mem
(
output_h
,
mems
[
i
])
,
)
if
self
.
output_hidden_states
:
if
self
.
output_hidden_states
:
hidden_states
.
append
((
output_h
,
output_g
)
if
output_g
is
not
None
else
output_h
)
hidden_states
.
append
((
output_h
,
output_g
)
if
output_g
is
not
None
else
output_h
)
...
@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs
=
[
output
.
permute
(
1
,
0
,
2
).
contiguous
(),
new_mems
]
outputs
=
(
output
.
permute
(
1
,
0
,
2
).
contiguous
(),
new_mems
)
if
self
.
output_hidden_states
:
if
self
.
output_hidden_states
:
if
output_g
is
not
None
:
if
output_g
is
not
None
:
hidden_states
=
[
h
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
for
h
in
hs
]
hidden_states
=
tuple
(
h
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
for
h
in
hs
)
else
:
else
:
hidden_states
=
[
hs
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
]
hidden_states
=
tuple
(
hs
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
)
outputs
.
append
(
hidden_states
)
outputs
+=
(
hidden_states
,
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
attentions
=
list
(
t
.
permute
(
2
,
3
,
0
,
1
).
contiguous
()
for
t
in
attentions
)
attentions
=
tuple
(
t
.
permute
(
2
,
3
,
0
,
1
).
contiguous
()
for
t
in
attentions
)
outputs
.
append
(
attentions
)
outputs
+=
(
attentions
,
)
return
outputs
# outputs, new_mems, (hidden_states), (attentions)
return
outputs
# outputs, new_mems, (hidden_states), (attentions)
...
@@ -986,7 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -986,7 +987,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def
tie_weights
(
self
):
def
tie_weights
(
self
):
""" Make sure we are sharing the embeddings
""" Make sure we are sharing the embeddings
"""
"""
self
.
lm_loss
.
weight
=
self
.
transformer
.
word_embedding
.
weight
self
.
lm_loss
.
weight
=
nn
.
Parameter
(
self
.
transformer
.
word_embedding
.
weight
.
clone
())
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
...
@@ -1026,14 +1027,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1026,14 +1027,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
logits
=
self
.
lm_loss
(
transformer_outputs
[
0
])
logits
=
self
.
lm_loss
(
transformer_outputs
[
0
])
outputs
=
[
logits
]
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
outputs
=
(
logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
if
labels
is
not
None
:
if
labels
is
not
None
:
# Flatten the tokens
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
size
(
-
1
)),
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
size
(
-
1
)),
labels
.
view
(
-
1
))
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
(
loss
,)
+
outputs
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
...
@@ -1061,7 +1062,7 @@ class XLNetSequenceSummary(nn.Module):
...
@@ -1061,7 +1062,7 @@ class XLNetSequenceSummary(nn.Module):
output
=
hidden_states
[:,
0
]
output
=
hidden_states
[:,
0
]
elif
self
.
summary_type
==
'mean'
:
elif
self
.
summary_type
==
'mean'
:
output
=
hidden_states
.
mean
(
dim
=
1
)
output
=
hidden_states
.
mean
(
dim
=
1
)
elif
summary_type
==
'attn'
:
elif
self
.
summary_type
==
'attn'
:
raise
NotImplementedError
raise
NotImplementedError
output
=
self
.
summary
(
output
)
output
=
self
.
summary
(
output
)
...
@@ -1180,7 +1181,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1180,7 +1181,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
output
=
self
.
sequence_summary
(
output
)
output
=
self
.
sequence_summary
(
output
)
logits
=
self
.
logits_proj
(
output
)
logits
=
self
.
logits_proj
(
output
)
outputs
=
[
logits
]
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
outputs
=
(
logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
if
labels
is
not
None
:
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
if
self
.
num_labels
==
1
:
...
@@ -1190,7 +1191,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1190,7 +1191,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
else
:
else
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
(
loss
,)
+
outputs
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
...
@@ -1271,7 +1272,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
...
@@ -1271,7 +1272,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
start_logits
=
start_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
outputs
=
[
start_logits
,
end_logits
]
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
outputs
=
(
start_logits
,
end_logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, split add a dimension
# If we are on multi-GPU, split add a dimension
...
@@ -1288,6 +1289,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
...
@@ -1288,6 +1289,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
total_loss
=
(
start_loss
+
end_loss
)
/
2
outputs
=
[
total_loss
]
+
outputs
outputs
=
(
total_loss
,)
+
outputs
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment