Commit d8e83de7 authored Jul 02, 2019 by LysandreJik

GPT2 can be exported to TorchScript

parent e891bb43
Showing 1 changed file with 16 additions and 15 deletions

pytorch_pretrained_bert/modeling_gpt2.py  (+16 / -15)

--- a/pytorch_pretrained_bert/modeling_gpt2.py
+++ b/pytorch_pretrained_bert/modeling_gpt2.py
@@ -328,7 +328,8 @@ class GPT2LMHead(nn.Module):
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        # Export to TorchScript can't handle parameter sharing so we are cloning them.
+        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
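
The commit's own comment explains this first change: TorchScript export could not handle parameter sharing, so the tied embedding/decoder weight is cloned into an independent nn.Parameter. Below is a minimal sketch of the pattern using a hypothetical TiedHead module with small dimensions, not code from this file; the trade-off is that the cloned decoder weight is no longer updated in lock-step with the embedding matrix, which only matters if the exported model is trained further.

import torch
import torch.nn as nn

class TiedHead(nn.Module):
    # Hypothetical stand-in for GPT2LMHead.set_embeddings_weights.
    def __init__(self, embedding_matrix):
        super(TiedHead, self).__init__()
        self.decoder = nn.Linear(embedding_matrix.size(1), embedding_matrix.size(0), bias=False)
        # Before this commit the weight was shared with the embedding (same Parameter object):
        #     self.decoder.weight = embedding_matrix
        # After this commit it is an independent copy, which the TorchScript exporter accepts:
        self.decoder.weight = nn.Parameter(embedding_matrix.clone())

    def forward(self, hidden_state):
        return self.decoder(hidden_state)

embedding_matrix = torch.randn(100, 16)   # small stand-in for the GPT-2 token embedding weights
head = TiedHead(embedding_matrix)
traced_head = torch.jit.trace(head, torch.randn(1, 4, 16))   # traces cleanly with the cloned weight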
@@ -557,16 +558,16 @@ class GPT2Model(GPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
-        presents = []
+        presents = ()
         all_attentions = []
-        all_hidden_states = []
+        all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states += (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, layer_past, head_mask[i])
             hidden_states, present = outputs[:2]
-            presents.append(present)
+            presents += (present,)
             if self.output_attentions:
                 all_attentions.append(outputs[2])
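
The remaining hunks all make the same kind of change: the Python lists that accumulated presents, hidden states, attentions, and the final outputs become tuples. The likely reason, stated here as an assumption rather than anything documented in the commit, is that torch.jit.trace on the PyTorch versions of that era only accepted tensors and (possibly nested) tuples of tensors as traced outputs, so a forward() that builds and returns lists could not be exported. A small self-contained illustration of the tuple-accumulation pattern, with a hypothetical collect_powers function standing in for the real forward loop:

import torch

def collect_powers(x):
    outputs = ()                  # tuple accumulator, mirroring `presents = ()`
    for i in range(3):
        outputs += (x ** i,)      # grow the tuple, mirroring `presents += (present,)`
    return outputs                # a tuple of tensors is a legal traced output

traced = torch.jit.trace(collect_powers, torch.ones(2))
print(traced(torch.tensor([2.0, 3.0])))   # (tensor([1., 1.]), tensor([2., 3.]), tensor([4., 9.]))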
@@ -576,16 +577,16 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states += (hidden_states,)

-        outputs = [hidden_states, presents]
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs += (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
-            all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
-            outputs.append(all_attentions)
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
+            outputs += (all_attentions,)
         return outputs  # last hidden state, presents, (all hidden_states), (attentions)
@@ -658,7 +659,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -667,7 +668,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
@@ -750,18 +751,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)

-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                             mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
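
Taken together, these changes make the GPT-2 classes traceable end to end, which is what the commit title refers to. A hedged sketch of exporting and reloading a traced model follows; the from_pretrained() call, the config.vocab_size attribute, and the two-element output are assumptions based on the pytorch_pretrained_bert API of the time, not something shown in this diff.

import torch
from pytorch_pretrained_bert import GPT2Model

model = GPT2Model.from_pretrained('gpt2')   # assumed pytorch_pretrained_bert API
model.eval()

# Any well-formed batch of token ids works as the example input: trace only records
# the tensor operations executed for this particular input.
input_ids = torch.randint(0, model.config.vocab_size, (1, 8), dtype=torch.long)

traced_model = torch.jit.trace(model, (input_ids,))
traced_model.save("gpt2_traced.pt")

loaded = torch.jit.load("gpt2_traced.pt")
hidden_states, presents = loaded(input_ids)   # last hidden state and the presents tuple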