chenpangpang/transformers · Commits

Commit f12007e4, authored Jun 17, 2019 by thomwolf
Parent: b860e47c

    add head masking and pruning to openai GPT
Showing 2 changed files with 149 additions and 21 deletions:

    pytorch_pretrained_bert/modeling_openai.py   +79  -21
    tests/modeling_openai_test.py                +70   -0
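For orientation: the change threads an optional head_mask argument through Attention, Block, and the three OpenAI GPT model classes, and adds permanent head pruning built on prune_conv1d_layer. A minimal usage sketch of the new options, assuming this revision of pytorch_pretrained_bert (config values and token ids are illustrative):

    import torch
    from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTModel

    config = OpenAIGPTConfig(vocab_size_or_config_json_file=40478)
    model = OpenAIGPTModel(config, keep_multihead_output=True)
    model.eval()

    input_ids = torch.tensor([[31, 51, 99]])

    # In this commit, 1.0 in head_mask marks a head to mask (see forward() below).
    head_mask = torch.zeros(config.n_head)
    head_mask[0] = 1.0  # silence head 0 in every layer

    hidden_states = model(input_ids, head_mask=head_mask)

    # Or drop heads permanently instead of masking them at run time:
    model.prune_heads({0: [1, 2]})  # remove heads 1 and 2 of layer 0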
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -36,6 +36,7 @@ from torch.nn.parameter import Parameter
 from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
 from .modeling import BertLayerNorm as LayerNorm
+from .modeling_gpt2 import prune_conv1d_layer

 logger = logging.getLogger(__name__)
@@ -256,7 +257,7 @@ class Conv1D(nn.Module):

 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -265,13 +266,31 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
         self.output_attentions = output_attentions
+        self.keep_multihead_output = keep_multihead_output
+        self.multihead_output = None
+
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)

-    def _attn(self, q, k, v):
+    def prune_heads(self, heads):
+        mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+        # Prune conv1d layers
+        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+        # Update hyper params
+        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+        self.n_head = self.n_head - len(heads)
+
+    def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
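A worked sketch of the index arithmetic in prune_heads, with hypothetical small sizes: the per-head mask is flattened over the head_dim columns, the surviving column indices are gathered, and index_attn repeats them at offsets 0, split_size, and 2 * split_size because c_attn packs the Q, K, and V projections side by side:

    import torch

    n_head, head_dim = 4, 3          # hypothetical sizes; split_size = 12
    split_size = n_head * head_dim
    heads = [1]                      # prune head 1

    mask = torch.ones(n_head, head_dim)
    for head in heads:
        mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)
    index = torch.arange(len(mask))[mask].long()
    # index == tensor([0, 1, 2, 6, 7, 8, 9, 10, 11]): head 1's columns are gone
    index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])
    # the same surviving columns, repeated for the packed Q, K and V blocks of c_attn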
@@ -282,6 +301,11 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)

+        # Mask heads if we want to
+        if head_mask is not None:
+            w = w * head_mask
+
         if self.output_attentions:
             return w, torch.matmul(w, v)
         return torch.matmul(w, v)
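Note the ordering: the mask multiplies the attention probabilities after softmax and dropout, so a masked head contributes zeros to the weighted sum while the remaining heads' probabilities are not renormalized. An illustrative shape check (sizes are hypothetical):

    import torch

    w = torch.rand(2, 4, 5, 5).softmax(dim=-1)               # (bsz, n_head, seq, seq)
    keep = torch.tensor([1., 0., 1., 1.]).view(1, 4, 1, 1)   # already-inverted keep factor
    w = w * keep
    assert w[:, 1].abs().sum().item() == 0.0                 # head 1 is silenced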
@@ -299,13 +323,18 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)

-    def forward(self, x):
+    def forward(self, x, head_mask=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
-        a = self._attn(query, key, value)
+        a = self._attn(query, key, value, head_mask)
+        if self.keep_multihead_output:
+            self.multihead_output = a
+            self.multihead_output.retain_grad()
+
         if self.output_attentions:
             attentions, a = a
         a = self.merge_heads(a)
@@ -332,17 +361,17 @@ class MLP(nn.Module):

 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Block, self).__init__()
         nx = config.n_embd
         self.output_attentions = output_attentions
-        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)

-    def forward(self, x):
-        a = self.attn(x)
+    def forward(self, x, head_mask=None):
+        a = self.attn(x, head_mask=head_mask)
         if self.output_attentions:
             attentions, a = a
         n = self.ln_1(x + a)
@@ -614,13 +643,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTModel, self).__init__(config)
         self.output_attentions = output_attentions
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
+                      keep_multihead_output=keep_multihead_output)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])

         self.apply(self.init_weights)
@@ -639,7 +669,20 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         # Copy word embeddings from the previous weights
         self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.h[layer].attn.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [h.attn.multihead_output for h in self.h]
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
         if position_ids is None:
             # This was used when we had a single embedding matrix for position and token embeddings
             # start = self.config.vocab_size + self.config.n_special
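These model-level methods are thin dispatch over the per-layer Attention modules. A hedged sketch of the intended calls, assuming model is an OpenAIGPTModel built with keep_multihead_output=True (layer and head indices are illustrative):

    # {layer_num: [head indices to drop in that layer]}; negative keys work
    # because self.h is an nn.ModuleList
    model.prune_heads({0: [0, 2], -1: [1]})

    # After a forward/backward pass, per-layer head outputs with retained
    # gradients can be inspected:
    multihead_outputs = model.get_multihead_outputs()
    # list of n_layer tensors shaped (bsz, n_head_remaining, seq, head_dim)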
@@ -648,6 +691,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
             position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each instance in batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_ids.size(-1))
         position_ids = position_ids.view(-1, position_ids.size(-1))
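The unsqueeze branches normalize head_mask into a 4-D tensor that broadcasts against attention weights of shape (bsz, n_heads, seq, seq): a 1-D mask applies one pattern to the whole batch, a 2-D mask one pattern per instance; the final subtraction turns the 1.0-means-mask convention into a multiplicative keep factor. An illustrative broadcast check (sizes are hypothetical):

    import torch

    n_head = 4
    head_mask = torch.zeros(n_head)
    head_mask[2] = 1.0                  # 1.0 marks the head to silence
    if head_mask.dim() == 1:            # one pattern shared by the whole batch
        head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    assert head_mask.shape == (1, n_head, 1, 1)   # broadcasts over bsz and both seq dims
    head_mask = 1.0 - head_mask         # multiplicative keep factor: 0.0 silences head 2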
@@ -664,11 +718,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         all_attentions = []
         for block in self.h:
+            outputs = block(hidden_states, head_mask)
             if self.output_attentions:
-                attentions, hidden_states = block(hidden_states)
+                attentions, hidden_states = outputs
                 all_attentions.append(attentions)
             else:
-                hidden_states = block(hidden_states)
+                hidden_states = outputs
         output_shape = input_shape + (hidden_states.size(-1),)
         if self.output_attentions:
             return all_attentions, hidden_states.view(*output_shape)
@@ -731,9 +786,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTLMHeadModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
+                                          keep_multihead_output=keep_multihead_output)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.apply(self.init_weights)
@@ -745,8 +801,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
@@ -825,9 +881,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
+                                          keep_multihead_output=keep_multihead_output)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -840,8 +897,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None,
+                head_mask=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
--- a/tests/modeling_openai_test.py
+++ b/tests/modeling_openai_test.py
@@ -182,6 +182,73 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 [list(l.size()) for l in result["loss"]],
                 [[], []])

+        def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                    mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.n_head).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0  # Mask all but the first and last heads
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids, head_mask=head_mask)
+                else:
+                    output = model(input_ids, head_mask=head_mask)
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.n_head - 1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.n_head - 1, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+        def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                     mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
+                heads_to_prune = {0: list(range(1, self.n_head)),
+                                  -1: [0]}
+                transformer.prune_heads(heads_to_prune)
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids)
+                else:
+                    output = model(input_ids)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = transformer.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, 1,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head - 1,
+                     self.seq_length, self.n_embd // self.n_head])
+
         def test_default(self):
             self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
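The nonzero-count assertions above follow from the stored head-output shape (batch_size * n_choices, n_head, seq_length, n_embd // n_head): a fully kept head is dense, so it should contribute batch_size * n_choices * seq_length * head_dim nonzero entries, while a masked head contributes none. A quick arithmetic check with hypothetical tester sizes:

    batch_size, n_choices, seq_length, n_embd, n_head = 2, 3, 7, 32, 4
    head_dim = n_embd // n_head                      # 8
    nonzero_per_kept_head = batch_size * n_choices * seq_length * head_dim
    assert nonzero_per_kept_head == 336              # expected for each unmasked head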
@@ -220,6 +287,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
             tester.check_openai_double_heads_output(output_result)
             tester.check_openai_double_heads_loss_output(output_result)

+        tester.create_and_check_openai_for_headmasking(*config_and_inputs)
+        tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""