Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b860e47c
Commit
b860e47c
authored
Jun 17, 2019
by
thomwolf
Browse files
add head masking and pruning to gpt-2
parent
7220d47a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
172 additions
and
21 deletions
+172
-21
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+102
-21
tests/modeling_gpt2_test.py
tests/modeling_gpt2_test.py
+70
-0
No files found.
pytorch_pretrained_bert/modeling_gpt2.py
View file @
b860e47c
...
...
@@ -44,6 +44,30 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"gpt2"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"
,
"gpt2-medium"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"
}
def
prune_conv1d_layer
(
layer
,
index
,
dim
=
1
):
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
index
=
index
.
to
(
layer
.
weight
.
device
)
W
=
layer
.
weight
.
index_select
(
dim
,
index
).
clone
().
detach
()
if
dim
==
0
:
b
=
layer
.
bias
.
clone
().
detach
()
else
:
b
=
layer
.
bias
[
index
].
clone
().
detach
()
new_size
=
list
(
layer
.
weight
.
size
())
new_size
[
dim
]
=
len
(
index
)
new_layer
=
Conv1D
(
new_size
[
1
],
new_size
[
0
])
new_layer
.
weight
.
requires_grad
=
False
new_layer
.
weight
.
copy_
(
W
.
contiguous
())
new_layer
.
weight
.
requires_grad
=
True
new_layer
.
bias
.
requires_grad
=
False
new_layer
.
bias
.
copy_
(
b
.
contiguous
())
new_layer
.
bias
.
requires_grad
=
True
return
new_layer
def
load_tf_weights_in_gpt2
(
model
,
gpt2_checkpoint_path
):
""" Load tf checkpoints in a pytorch model
"""
...
...
@@ -223,7 +247,7 @@ class Conv1D(nn.Module):
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
nx
,
n_ctx
,
config
,
scale
=
False
,
output_attentions
=
False
):
def
__init__
(
self
,
nx
,
n_ctx
,
config
,
scale
=
False
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
Attention
,
self
).
__init__
()
n_state
=
nx
# in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
...
...
@@ -232,13 +256,31 @@ class Attention(nn.Module):
self
.
n_head
=
config
.
n_head
self
.
split_size
=
n_state
self
.
scale
=
scale
self
.
output_attentions
=
output_attentions
self
.
keep_multihead_output
=
keep_multihead_output
self
.
multihead_output
=
None
self
.
c_attn
=
Conv1D
(
n_state
*
3
,
nx
)
self
.
c_proj
=
Conv1D
(
n_state
,
nx
)
self
.
attn_dropout
=
nn
.
Dropout
(
config
.
attn_pdrop
)
self
.
resid_dropout
=
nn
.
Dropout
(
config
.
resid_pdrop
)
def
_attn
(
self
,
q
,
k
,
v
):
def
prune_heads
(
self
,
heads
):
mask
=
torch
.
ones
(
self
.
n_head
,
self
.
split_size
//
self
.
n_head
)
for
head
in
heads
:
mask
[
head
]
=
0
mask
=
mask
.
view
(
-
1
).
contiguous
().
eq
(
1
)
index
=
torch
.
arange
(
len
(
mask
))[
mask
].
long
()
index_attn
=
torch
.
cat
([
index
,
index
+
self
.
split_size
,
index
+
(
2
*
self
.
split_size
)])
# Prune conv1d layers
self
.
c_attn
=
prune_conv1d_layer
(
self
.
c_attn
,
index_attn
,
dim
=
1
)
self
.
c_proj
=
prune_conv1d_layer
(
self
.
c_proj
,
index
,
dim
=
0
)
# Update hyper params
self
.
split_size
=
(
self
.
split_size
//
self
.
n_head
)
*
(
self
.
n_head
-
len
(
heads
))
self
.
n_head
=
self
.
n_head
-
len
(
heads
)
def
_attn
(
self
,
q
,
k
,
v
,
head_mask
=
None
):
w
=
torch
.
matmul
(
q
,
k
)
if
self
.
scale
:
w
=
w
/
math
.
sqrt
(
v
.
size
(
-
1
))
...
...
@@ -248,6 +290,11 @@ class Attention(nn.Module):
w
=
nn
.
Softmax
(
dim
=-
1
)(
w
)
w
=
self
.
attn_dropout
(
w
)
# Mask heads if we want to
if
head_mask
is
not
None
:
w
=
w
*
head_mask
if
self
.
output_attentions
:
return
w
,
torch
.
matmul
(
w
,
v
)
return
torch
.
matmul
(
w
,
v
)
...
...
@@ -265,7 +312,7 @@ class Attention(nn.Module):
else
:
return
x
.
permute
(
0
,
2
,
1
,
3
)
# (batch, head, seq_length, head_features)
def
forward
(
self
,
x
,
layer_past
=
None
):
def
forward
(
self
,
x
,
layer_past
=
None
,
head_mask
=
None
):
x
=
self
.
c_attn
(
x
)
query
,
key
,
value
=
x
.
split
(
self
.
split_size
,
dim
=
2
)
query
=
self
.
split_heads
(
query
)
...
...
@@ -276,7 +323,12 @@ class Attention(nn.Module):
key
=
torch
.
cat
((
past_key
,
key
),
dim
=-
1
)
value
=
torch
.
cat
((
past_value
,
value
),
dim
=-
2
)
present
=
torch
.
stack
((
key
.
transpose
(
-
2
,
-
1
),
value
))
# transpose to have same shapes for stacking
a
=
self
.
_attn
(
query
,
key
,
value
)
a
=
self
.
_attn
(
query
,
key
,
value
,
head_mask
)
if
self
.
keep_multihead_output
:
self
.
multihead_output
=
a
self
.
multihead_output
.
retain_grad
()
if
self
.
output_attentions
:
attentions
,
a
=
a
a
=
self
.
merge_heads
(
a
)
...
...
@@ -303,17 +355,17 @@ class MLP(nn.Module):
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
n_ctx
,
config
,
scale
=
False
,
output_attentions
=
False
):
def
__init__
(
self
,
n_ctx
,
config
,
scale
=
False
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
Block
,
self
).
__init__
()
nx
=
config
.
n_embd
self
.
output_attentions
=
output_attentions
self
.
ln_1
=
LayerNorm
(
nx
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
Attention
(
nx
,
n_ctx
,
config
,
scale
,
output_attentions
)
self
.
attn
=
Attention
(
nx
,
n_ctx
,
config
,
scale
,
output_attentions
,
keep_multihead_output
)
self
.
ln_2
=
LayerNorm
(
nx
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
MLP
(
4
*
nx
,
config
)
def
forward
(
self
,
x
,
layer_past
=
None
):
output_attn
=
self
.
attn
(
self
.
ln_1
(
x
),
layer_past
=
layer_past
)
def
forward
(
self
,
x
,
layer_past
=
None
,
head_mask
=
None
):
output_attn
=
self
.
attn
(
self
.
ln_1
(
x
),
layer_past
=
layer_past
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
attentions
,
a
,
present
=
output_attn
else
:
...
...
@@ -593,13 +645,14 @@ class GPT2Model(GPT2PreTrainedModel):
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
GPT2Model
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
wte
=
nn
.
Embedding
(
config
.
total_tokens_embeddings
,
config
.
n_embd
)
self
.
wpe
=
nn
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
,
output_attentions
=
output_attentions
)
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
h
=
nn
.
ModuleList
([
copy
.
deepcopy
(
block
)
for
_
in
range
(
config
.
n_layer
)])
self
.
ln_f
=
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
...
...
@@ -619,7 +672,20 @@ class GPT2Model(GPT2PreTrainedModel):
# Copy word embeddings from the previous weights
self
.
wte
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
past
=
None
):
def
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
h
.
attn
.
multihead_output
for
h
in
self
.
h
]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
past
=
None
,
head_mask
=
None
):
if
past
is
None
:
past_length
=
0
past
=
[
None
]
*
len
(
self
.
h
)
...
...
@@ -629,6 +695,17 @@ class GPT2Model(GPT2PreTrainedModel):
position_ids
=
torch
.
arange
(
past_length
,
input_ids
.
size
(
-
1
)
+
past_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we mask the head
# attention_probs has shape bsz x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each instance in batch
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
head_mask
=
(
1.0
-
head_mask
)
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
...
...
@@ -646,11 +723,12 @@ class GPT2Model(GPT2PreTrainedModel):
presents
=
[]
all_attentions
=
[]
for
block
,
layer_past
in
zip
(
self
.
h
,
past
):
outputs
=
block
(
hidden_states
,
layer_past
,
head_mask
)
if
self
.
output_attentions
:
attentions
,
hidden_states
,
present
=
block
(
hidden_states
,
layer_past
)
attentions
,
hidden_states
,
present
=
outputs
all_attentions
.
append
(
attentions
)
else
:
hidden_states
,
present
=
block
(
hidden_states
,
layer_past
)
hidden_states
,
present
=
outputs
presents
.
append
(
present
)
hidden_states
=
self
.
ln_f
(
hidden_states
)
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
...
...
@@ -703,9 +781,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
GPT2LMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
GPT2Model
(
config
,
output_attentions
=
output_attentions
)
self
.
transformer
=
GPT2Model
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_head
=
GPT2LMHead
(
self
.
transformer
.
wte
.
weight
,
config
)
self
.
apply
(
self
.
init_weights
)
...
...
@@ -717,8 +796,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
wte
.
weight
,
predict_special_tokens
=
predict_special_tokens
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
,
past
=
None
):
transformer_output
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_output
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
,
head_mask
)
if
self
.
transformer
.
output_attentions
:
all_attentions
,
hidden_states
,
presents
=
transformer_output
else
:
...
...
@@ -787,9 +866,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
GPT2DoubleHeadsModel
,
self
).
__init__
(
config
)
self
.
transformer
=
GPT2Model
(
config
,
output_attentions
=
output_attentions
)
self
.
transformer
=
GPT2Model
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_head
=
GPT2LMHead
(
self
.
transformer
.
wte
.
weight
,
config
)
self
.
multiple_choice_head
=
GPT2MultipleChoiceHead
(
config
)
self
.
apply
(
self
.
init_weights
)
...
...
@@ -802,8 +882,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
wte
.
weight
,
predict_special_tokens
=
predict_special_tokens
)
def
forward
(
self
,
input_ids
,
mc_token_ids
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
past
=
None
):
transformer_output
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
)
def
forward
(
self
,
input_ids
,
mc_token_ids
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_output
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
,
head_mask
)
if
self
.
transformer
.
output_attentions
:
all_attentions
,
hidden_states
,
presents
=
transformer_output
else
:
...
...
tests/modeling_gpt2_test.py
View file @
b860e47c
...
...
@@ -209,6 +209,73 @@ class GPT2ModelTest(unittest.TestCase):
[
list
(
l
.
size
())
for
l
in
result
[
"loss"
]],
[[],
[]])
def
create_and_check_gpt2_for_headmasking
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
for
model_class
in
(
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
):
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
.
eval
()
head_mask
=
torch
.
ones
(
self
.
n_head
).
to
(
input_ids
.
device
)
head_mask
[
0
]
=
0.0
head_mask
[
-
1
]
=
0.0
# Mask all but the first and last heads
if
isinstance
(
model
,
GPT2DoubleHeadsModel
):
output
=
model
(
input_ids
,
mc_token_ids
,
head_mask
=
head_mask
)
else
:
output
=
model
(
input_ids
,
head_mask
=
head_mask
)
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
output
=
output
.
sum
()
output
.
backward
()
multihead_outputs
=
(
model
if
isinstance
(
model
,
GPT2Model
)
else
model
.
transformer
).
get_multihead_outputs
()
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
0
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
1
:(
self
.
n_head
-
1
),
:,
:].
nonzero
()),
0
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
0
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
n_choices
*
self
.
seq_length
*
self
.
n_embd
//
self
.
n_head
)
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
[
0
][:,
self
.
n_head
-
1
,
:,
:].
nonzero
()),
self
.
batch_size
*
self
.
n_choices
*
self
.
seq_length
*
self
.
n_embd
//
self
.
n_head
)
def
create_and_check_gpt2_for_head_pruning
(
self
,
config
,
input_ids
,
token_type_ids
,
position_ids
,
mc_labels
,
lm_labels
,
mc_token_ids
):
for
model_class
in
(
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
):
model
=
model_class
(
config
=
config
,
keep_multihead_output
=
True
)
model
.
eval
()
transformer
=
model
if
isinstance
(
model
,
GPT2Model
)
else
model
.
transformer
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
n_head
)),
-
1
:
[
0
]}
transformer
.
prune_heads
(
heads_to_prune
)
if
isinstance
(
model
,
GPT2DoubleHeadsModel
):
output
=
model
(
input_ids
,
mc_token_ids
)
else
:
output
=
model
(
input_ids
)
output
=
sum
(
t
.
sum
()
for
t
in
output
[:
-
1
])
output
=
output
.
sum
()
output
.
backward
()
multihead_outputs
=
transformer
.
get_multihead_outputs
()
self
.
parent
.
assertEqual
(
len
(
multihead_outputs
),
self
.
n_layer
)
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
0
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
1
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
1
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
self
.
parent
.
assertListEqual
(
list
(
multihead_outputs
[
-
1
].
size
()),
[
self
.
batch_size
*
self
.
n_choices
,
self
.
n_head
-
1
,
self
.
seq_length
,
self
.
n_embd
//
self
.
n_head
])
def
test_default
(
self
):
self
.
run_tester
(
GPT2ModelTest
.
GPT2ModelTester
(
self
))
...
...
@@ -247,6 +314,9 @@ class GPT2ModelTest(unittest.TestCase):
tester
.
check_gpt2_double_heads_output
(
output_result
)
tester
.
check_gpt2_double_heads_loss_output
(
output_result
)
tester
.
create_and_check_gpt2_for_headmasking
(
*
config_and_inputs
)
tester
.
create_and_check_gpt2_for_head_pruning
(
*
config_and_inputs
)
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a random int32 tensor of the shape within the vocab size."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment