Commit 275179a0 authored May 08, 2019 by thomwolf
output attentions in GPT-2
parent 366a3b02
Showing 1 changed file with 47 additions and 13 deletions.

pytorch_pretrained_bert/modeling_gpt2.py (+47 −13)
@@ -223,7 +223,7 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -232,6 +232,7 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.output_attentions = output_attentions
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -247,6 +248,8 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
+        if self.output_attentions:
+            return w, torch.matmul(w, v)
         return torch.matmul(w, v)

     def merge_heads(self, x):
@@ -274,9 +277,13 @@ class Attention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         a = self._attn(query, key, value)
+        if self.output_attentions:
+            attentions, a = a
         a = self.merge_heads(a)
         a = self.c_proj(a)
         a = self.resid_dropout(a)
+        if self.output_attentions:
+            return attentions, a, present
         return a, present
@@ -296,19 +303,26 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
         super(Block, self).__init__()
         nx = config.n_embd
+        self.output_attentions = output_attentions
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)

     def forward(self, x, layer_past=None):
-        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
+        output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
+        if self.output_attentions:
+            attentions, a, present = output_attn
+        else:
+            a, present = output_attn
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
+        if self.output_attentions:
+            return attentions, x, present
         return x, present
@@ -567,12 +581,13 @@ class GPT2Model(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2Model, self).__init__(config)
+        self.output_attentions = output_attentions
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
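Since `GPT2Model` builds a single `Block` and deep-copies it `n_layer` times, the flag set here propagates to every layer. A minimal sketch to illustrate (not part of the commit; it assumes `GPT2Config`'s defaults are acceptable for a randomly initialized model):

```python
# Sketch only: checks that the output_attentions flag reaches every copied Block.
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2Model

config = GPT2Config()  # default (GPT-2 small) hyperparameters, random weights
model = GPT2Model(config, output_attentions=True)

# copy.deepcopy preserves the attribute set in Block.__init__
assert all(block.output_attentions for block in model.h)
```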
@@ -617,11 +632,18 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = self.drop(hidden_states)
         presents = []
+        all_attentions = []
         for block, layer_past in zip(self.h, past):
-            hidden_states, present = block(hidden_states, layer_past)
+            if self.output_attentions:
+                attentions, hidden_states, present = block(hidden_states, layer_past)
+                all_attentions.append(attentions)
+            else:
+                hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
+        if self.output_attentions:
+            return all_attentions, hidden_states.view(*output_shape), presents
         return hidden_states.view(*output_shape), presents
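With the change above, `GPT2Model.forward` returns the per-layer attention probabilities ahead of the usual hidden states and presents when the flag is on. A hedged usage sketch (input shapes and the assertion are assumptions drawn from the surrounding code, not part of this diff):

```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2Model

config = GPT2Config()                          # GPT-2 small defaults, random weights
model = GPT2Model(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 16))   # (batch, seq_len); 50257 = default vocabulary size
with torch.no_grad():
    all_attentions, hidden_states, presents = model(input_ids)

# one attention tensor per layer, each expected to be (batch, n_head, seq_len, seq_len)
assert len(all_attentions) == config.n_layer
```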
@@ -669,9 +691,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2LMHeadModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)
@@ -684,7 +706,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -695,6 +721,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, presents
         return lm_logits, presents
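For `GPT2LMHeadModel` the attention tensors are only exposed on the no-label path; when `lm_labels` is given, the method still returns just the loss, which this commit leaves unchanged. A minimal sketch under the same assumptions as above:

```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

config = GPT2Config()
model = GPT2LMHeadModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 16))
with torch.no_grad():
    # three-element return because the transformer was built with output_attentions=True
    all_attentions, lm_logits, presents = model(input_ids)
```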
@@ -747,9 +775,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2DoubleHeadsModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -763,7 +791,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
@@ -777,4 +809,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
         if losses:
             return losses
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, mc_logits, presents
         return lm_logits, mc_logits, presents
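The same pattern holds for `GPT2DoubleHeadsModel`: attentions are returned first when no labels are passed, while the label path still returns the list of losses. A sketch with input shapes assumed here, not taken from the diff (multiple-choice inputs of shape (batch, num_choices, seq_len) and per-choice classification-token indices of shape (batch, num_choices)):

```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2DoubleHeadsModel

config = GPT2Config()
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 2, 16))           # (batch, num_choices, seq_len) -- assumed layout
mc_token_ids = torch.full((1, 2), 15, dtype=torch.long)   # index of the classification token in each choice
with torch.no_grad():
    all_attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids)
```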