transformers · Commit e211785a
Authored May 02, 2019 by thomwolf
extract attention weights from GPT
Parent: db98a4a4
Showing 1 changed file with 38 additions and 10 deletions.

pytorch_pretrained_bert/modeling_openai.py (+38, -10)
@@ -253,7 +253,7 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -262,6 +262,7 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.output_attentions = output_attentions
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -278,6 +279,8 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
+        if self.output_attentions:
+            return w, torch.matmul(w, v)
         return torch.matmul(w, v)

     def merge_heads(self, x):
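The core of the change: when output_attentions is set, _attn hands back the post-softmax attention map alongside the usual weighted sum of the values. A standalone sketch of the two tensors involved (plain PyTorch, not the library code; the shapes are hypothetical but match the default 12-head, 64-dim-per-head GPT configuration):

import torch

# Hypothetical shapes: batch of 2, 12 heads, 5 tokens, 64-dim heads.
batch, n_head, seq_len, head_dim = 2, 12, 5, 64
q = torch.randn(batch, n_head, seq_len, head_dim)
k = torch.randn(batch, n_head, seq_len, head_dim)
v = torch.randn(batch, n_head, seq_len, head_dim)

w = torch.matmul(q, k.transpose(-1, -2)) / head_dim ** 0.5  # scaled dot-product scores
w = torch.softmax(w, dim=-1)   # attention weights: (batch, n_head, seq_len, seq_len)
context = torch.matmul(w, v)   # weighted values:   (batch, n_head, seq_len, head_dim)

# With output_attentions=True the patched _attn returns both tensors as a tuple
# instead of only the context tensor.
output = (w, context)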
@@ -300,9 +303,13 @@ class Attention(nn.Module):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         a = self._attn(query, key, value)
+        if self.output_attentions:
+            attentions, a = a
         a = self.merge_heads(a)
         a = self.c_proj(a)
         a = self.resid_dropout(a)
+        if self.output_attentions:
+            return attentions, a
         return a
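Attention.forward then unpacks and re-emits that pair, so callers receive (attention_weights, output). A hedged usage sketch against this version of the module, assuming OpenAIGPTConfig's default hyper-parameters and randomly initialized weights:

import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, Attention

config = OpenAIGPTConfig()   # default GPT hyper-parameters (assumption: defaults are used)
attn = Attention(config.n_embd, config.n_ctx, config, scale=True, output_attentions=True)
attn.eval()                  # disable attn_dropout so the returned rows stay normalized

x = torch.randn(2, 7, config.n_embd)   # hypothetical (batch, seq_len, hidden) input
attentions, out = attn(x)

print(attentions.shape)   # (2, config.n_head, 7, 7)
print(attentions.sum(-1)) # each row sums to ~1.0: these are the softmax weights
print(out.shape)          # (2, 7, config.n_embd)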
@@ -322,19 +329,24 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
         super(Block, self).__init__()
         nx = config.n_embd
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.output_attentions = output_attentions
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)

     def forward(self, x):
         a = self.attn(x)
+        if self.output_attentions:
+            attentions, a = a
         n = self.ln_1(x + a)
         m = self.mlp(n)
         h = self.ln_2(n + m)
+        if self.output_attentions:
+            return attentions, h
         return h
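Block follows the same convention: constructed with output_attentions=True, its forward pass returns the attention weights of its sub-layer together with the usual hidden states. A minimal sketch under the same assumptions (default config, random weights, hypothetical input shape):

import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, Block

config = OpenAIGPTConfig()
block = Block(config.n_ctx, config, scale=True, output_attentions=True)
block.eval()

hidden_states = torch.randn(1, 10, config.n_embd)       # hypothetical (batch, seq_len, hidden)
attentions, hidden_states = block(hidden_states)

print(attentions.shape)     # (1, config.n_head, 10, 10)
print(hidden_states.shape)  # (1, 10, config.n_embd)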
@@ -591,12 +603,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(OpenAIGPTModel, self).__init__(config)
+        self.output_attentions = output_attentions
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.apply(self.init_weights)
@@ -639,9 +652,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         # Add the position information to the input embeddings
         # h = e.sum(dim=2)
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        all_attentions = []
         for block in self.h:
-            hidden_states = block(hidden_states)
+            if self.output_attentions:
+                attentions, hidden_states = block(hidden_states)
+                all_attentions.append(attentions)
+            else:
+                hidden_states = block(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
+        if self.output_attentions:
+            return all_attentions, hidden_states.view(*output_shape)
         return hidden_states.view(*output_shape)
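At the model level the flag is propagated to every copied Block, and the forward pass collects one attention tensor per layer. A hedged end-to-end sketch with a randomly initialized model (no pretrained weights; the token ids are arbitrary values below the default 40478-token vocabulary):

import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTModel

config = OpenAIGPTConfig()
model = OpenAIGPTModel(config, output_attentions=True)
model.eval()

input_ids = torch.tensor([[40, 41, 42, 43, 44]])   # hypothetical (batch=1, seq_len=5) ids
with torch.no_grad():
    all_attentions, hidden_states = model(input_ids)

print(len(all_attentions))      # config.n_layer tensors, one per Block
print(all_attentions[0].shape)  # (1, config.n_head, 5, 5)
print(hidden_states.shape)      # (1, 5, config.n_embd)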
@@ -701,9 +721,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(OpenAIGPTLMHeadModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.apply(self.init_weights)
@@ -716,6 +736,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -726,6 +748,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits
         return lm_logits
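For the LM head the attentions are only surfaced on the logits path; when lm_labels is given, the loss is returned as before. Extra keyword arguments passed to from_pretrained appear to be forwarded to the constructor in this version of the library, so the flag should be usable there as well, though that path is not shown in this sketch (same assumptions: default config, random weights, hypothetical ids):

import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTLMHeadModel

config = OpenAIGPTConfig()
model = OpenAIGPTLMHeadModel(config, output_attentions=True)
model.eval()

input_ids = torch.tensor([[40, 41, 42, 43]])
with torch.no_grad():
    # No lm_labels supplied, so logits (not a loss) come back, preceded by the attentions.
    all_attentions, lm_logits = model(input_ids)

print(len(all_attentions))  # one tensor per layer
print(lm_logits.shape)      # (1, 4, config.total_tokens_embeddings)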
@@ -790,9 +814,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -806,6 +830,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
@@ -819,4 +845,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
         if losses:
             return losses
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, mc_logits
         return lm_logits, mc_logits
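The double-heads model mirrors this: with the flag set and no labels supplied, the per-layer attentions are returned ahead of both heads' logits. A sketch with hypothetical multiple-choice inputs (one batch, two candidate sequences of five tokens; mc_token_ids indexes the classification token in each candidate), again with a randomly initialized model:

import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTDoubleHeadsModel

config = OpenAIGPTConfig()
model = OpenAIGPTDoubleHeadsModel(config, output_attentions=True)
model.eval()

input_ids = torch.tensor([[[40, 41, 42, 43, 44],
                           [40, 41, 42, 45, 46]]])   # (batch=1, num_choices=2, seq_len=5)
mc_token_ids = torch.tensor([[4, 4]])                # position of the classification token per choice
with torch.no_grad():
    all_attentions, lm_logits, mc_logits = model(input_ids, mc_token_ids)

print(len(all_attentions))  # one tensor per layer (choices are flattened into the batch dim)
print(lm_logits.shape)      # (1, 2, 5, config.total_tokens_embeddings)
print(mc_logits.shape)      # (1, 2)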