chenpangpang / transformers

Commit b860e47c, authored Jun 17, 2019 by thomwolf
Parent: 7220d47a

    add head masking and pruning to gpt-2
Showing 2 changed files with 172 additions and 21 deletions:

    pytorch_pretrained_bert/modeling_gpt2.py    +102  -21
    tests/modeling_gpt2_test.py                  +70   -0
pytorch_pretrained_bert/modeling_gpt2.py
@@ -44,6 +44,30 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
 PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
                                  "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
 
+def prune_conv1d_layer(layer, index, dim=1):
+    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
+        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
+        Return the pruned layer as a new layer with requires_grad=True.
+        Used to remove heads.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).clone().detach()
+    if dim == 0:
+        b = layer.bias.clone().detach()
+    else:
+        b = layer.bias[index].clone().detach()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = Conv1D(new_size[1], new_size[0])
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    new_layer.bias.requires_grad = False
+    new_layer.bias.copy_(b.contiguous())
+    new_layer.bias.requires_grad = True
+    return new_layer
+
+
 def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
     """
@@ -223,7 +247,7 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -232,13 +256,31 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
         self.output_attentions = output_attentions
+        self.keep_multihead_output = keep_multihead_output
+        self.multihead_output = None
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
-    def _attn(self, q, k, v):
+    def prune_heads(self, heads):
+        mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+        # Prune conv1d layers
+        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+        # Update hyper params
+        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+        self.n_head = self.n_head - len(heads)
+
+    def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
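The index arithmetic above exists because c_attn packs the query, key and value projections along one output dimension of size 3 * split_size, so the columns kept for the surviving heads must be repeated at offsets 0, split_size and 2 * split_size. A standalone sketch with toy sizes (4 heads of width 2, pruning heads 1 and 3; the numbers are illustrative only):

import torch

n_head, head_dim = 4, 2
split_size = n_head * head_dim                       # 8
mask = torch.ones(n_head, head_dim)
for head in (1, 3):                                  # heads to prune
    mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()         # tensor([0, 1, 4, 5]) -> columns of heads 0 and 2
index_attn = torch.cat([index, index + split_size, index + 2 * split_size])
print(index_attn.tolist())                           # [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21]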
@@ -248,6 +290,11 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
 
+        # Mask heads if we want to
+        if head_mask is not None:
+            w = w * head_mask
+
         if self.output_attentions:
             return w, torch.matmul(w, v)
         return torch.matmul(w, v)
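Since the prepared head_mask arrives with shape [1, n_head, 1, 1] (or [bsz, n_head, 1, 1], see the mask preparation in GPT2Model.forward further down), the single multiplication above zeroes out entire heads by broadcasting. A quick standalone sketch with made-up sizes:

import torch

bsz, n_head, seq = 2, 4, 5
w = torch.softmax(torch.randn(bsz, n_head, seq, seq), dim=-1)   # stand-in for the attention probabilities
head_mask = torch.ones(1, n_head, 1, 1)
head_mask[0, 1] = 0.0                                           # silence head 1
masked_w = w * head_mask
assert masked_w[:, 1].abs().sum() == 0                          # head 1 contributes nothing downstream
assert torch.equal(masked_w[:, 0], w[:, 0])                     # the other heads are untouched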
@@ -265,7 +312,7 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
-    def forward(self, x, layer_past=None):
+    def forward(self, x, layer_past=None, head_mask=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
@@ -276,7 +323,12 @@ class Attention(nn.Module):
             key = torch.cat((past_key, key), dim=-1)
             value = torch.cat((past_value, value), dim=-2)
         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
-        a = self._attn(query, key, value)
+
+        a = self._attn(query, key, value, head_mask)
+        if self.keep_multihead_output:
+            self.multihead_output = a
+            self.multihead_output.retain_grad()
+
         if self.output_attentions:
             attentions, a = a
         a = self.merge_heads(a)
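keep_multihead_output stashes the per-head attention output and calls retain_grad() on it so that its gradient is still available after backward; by default PyTorch discards .grad on non-leaf tensors. A minimal standalone reminder of that behaviour:

import torch

x = torch.randn(3, requires_grad=True)
h = x * 2                      # intermediate (non-leaf) tensor, like the per-head attention output
h.retain_grad()                # without this, h.grad would be None after backward
h.sum().backward()
print(h.grad)                  # tensor([1., 1., 1.])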
@@ -303,17 +355,17 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Block, self).__init__()
         nx = config.n_embd
         self.output_attentions = output_attentions
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
 
-    def forward(self, x, layer_past=None):
-        output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
+    def forward(self, x, layer_past=None, head_mask=None):
+        output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
         if self.output_attentions:
             attentions, a, present = output_attn
         else:
@@ -593,13 +645,14 @@ class GPT2Model(GPT2PreTrainedModel):
     ```
     """
 
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(GPT2Model, self).__init__(config)
         self.output_attentions = output_attentions
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
+                      keep_multihead_output=keep_multihead_output)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -619,7 +672,20 @@ class GPT2Model(GPT2PreTrainedModel):
         # Copy word embeddings from the previous weights
         self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.h[layer].attn.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [h.attn.multihead_output for h in self.h]
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
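Taken together, the additions give GPT2Model three entry points: a per-call head_mask, a permanent prune_heads, and get_multihead_outputs for inspecting (and back-propagating through) each layer's per-head outputs. The sketch below assumes a pytorch_pretrained_bert checkout that includes this commit; GPT2Config() builds the full-size architecture with random weights, so nothing is downloaded, and 1.0 in head_mask means "silence this head".

import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2Model

config = GPT2Config()
model = GPT2Model(config, keep_multihead_output=True)
model.eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

# Head masking: silence head 3 in every layer, for this forward pass only.
head_mask = torch.zeros(config.n_head)
head_mask[3] = 1.0
hidden_states, presents = model(input_ids, head_mask=head_mask)
per_layer_head_outputs = model.get_multihead_outputs()   # one tensor per layer, gradients retained

# Head pruning: permanently remove heads, layer by layer.
model.prune_heads({0: [0, 1], config.n_layer - 1: [2]})
hidden_states, presents = model(input_ids)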
@@ -629,6 +695,17 @@ class GPT2Model(GPT2PreTrainedModel):
             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
 
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each instance in batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_ids.size(-1))
         position_ids = position_ids.view(-1, position_ids.size(-1))
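The block above accepts either a 1-D mask (shared across the batch) or a 2-D per-example mask, expands it to [*, n_head, 1, 1] so it broadcasts over the attention weights, casts it to the parameter dtype for fp16, and flips it, because callers pass 1.0 for heads to silence while the multiplication needs a 0.0 multiplier there. A standalone shape sketch:

import torch

n_head = 12

mask_1d = torch.zeros(n_head)
mask_1d[3] = 1.0                                   # ask to silence head 3
prepared = 1.0 - mask_1d.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(prepared.shape)                              # torch.Size([1, 12, 1, 1]); entry for head 3 is 0.0

mask_2d = torch.zeros(4, n_head)                   # one mask per example in a batch of 4
mask_2d[:, 3] = 1.0
prepared_2d = 1.0 - mask_2d.unsqueeze(-1).unsqueeze(-1)
print(prepared_2d.shape)                           # torch.Size([4, 12, 1, 1])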
@@ -646,11 +723,12 @@ class GPT2Model(GPT2PreTrainedModel):
         presents = []
         all_attentions = []
         for block, layer_past in zip(self.h, past):
+            outputs = block(hidden_states, layer_past, head_mask)
             if self.output_attentions:
-                attentions, hidden_states, present = block(hidden_states, layer_past)
+                attentions, hidden_states, present = outputs
                 all_attentions.append(attentions)
             else:
-                hidden_states, present = block(hidden_states, layer_past)
+                hidden_states, present = outputs
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
@@ -703,9 +781,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     ```
     """
 
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(GPT2LMHeadModel, self).__init__(config)
-        self.transformer = GPT2Model(config, output_attentions=output_attentions)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions,
+                                     keep_multihead_output=keep_multihead_output)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)
@@ -717,8 +796,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
-        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states, presents = transformer_output
         else:
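GPT2LMHeadModel only threads the new head_mask and keep_multihead_output arguments through to the underlying GPT2Model. A usage sketch under the same assumptions as the GPT2Model sketch above (in this version the forward returns lm_logits and presents when no lm_labels are given):

import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

config = GPT2Config()                               # random weights, no download
model = GPT2LMHeadModel(config)
model.eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

head_mask = torch.zeros(config.n_head)
head_mask[0] = 1.0                                  # silence head 0 in every layer
lm_logits, presents = model(input_ids, head_mask=head_mask)
print(lm_logits.shape)                              # (1, 8, vocabulary size incl. special tokens)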
@@ -787,9 +866,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     ```
     """
 
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(GPT2DoubleHeadsModel, self).__init__(config)
-        self.transformer = GPT2Model(config, output_attentions=output_attentions)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions,
+                                     keep_multihead_output=keep_multihead_output)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -802,8 +882,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
 
-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
-        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+                position_ids=None, past=None, head_mask=None):
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states, presents = transformer_output
         else:
tests/modeling_gpt2_test.py
@@ -209,6 +209,73 @@ class GPT2ModelTest(unittest.TestCase):
                 [list(l.size()) for l in result["loss"]],
                 [[], []])
 
+        def create_and_check_gpt2_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                  mc_labels, lm_labels, mc_token_ids):
+            for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.n_head).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0  # Mask all but the first and last heads
+                if isinstance(model, GPT2DoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids, head_mask=head_mask)
+                else:
+                    output = model(input_ids, head_mask=head_mask)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+        def create_and_check_gpt2_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                   mc_labels, lm_labels, mc_token_ids):
+            for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                transformer = model if isinstance(model, GPT2Model) else model.transformer
+                heads_to_prune = {0: list(range(1, self.n_head)),
+                                  -1: [0]}
+                transformer.prune_heads(heads_to_prune)
+                if isinstance(model, GPT2DoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids)
+                else:
+                    output = model(input_ids)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = transformer.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, 1,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head - 1,
+                     self.seq_length, self.n_embd // self.n_head])
+
     def test_default(self):
         self.run_tester(GPT2ModelTest.GPT2ModelTester(self))
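The shape and nonzero-count assertions above rely on .nonzero() returning one row per non-zero element, so a random (hence almost surely non-zero everywhere) head slice yields batch * choices * seq * head_dim rows, while a fully masked head yields none. A tiny standalone sketch of that counting logic:

import torch

bsz_choices, seq_len, head_dim = 6, 7, 8
active = torch.randn(bsz_choices, seq_len, head_dim)   # stand-in for an unmasked head's output
masked = torch.zeros(bsz_choices, seq_len, head_dim)   # stand-in for a masked head's output
assert len(active.nonzero()) == bsz_choices * seq_len * head_dim
assert len(masked.nonzero()) == 0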
@@ -247,6 +314,9 @@ class GPT2ModelTest(unittest.TestCase):
             tester.check_gpt2_double_heads_output(output_result)
             tester.check_gpt2_double_heads_loss_output(output_result)
 
+            tester.create_and_check_gpt2_for_headmasking(*config_and_inputs)
+            tester.create_and_check_gpt2_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""