Commit f12007e4
Authored Jun 17, 2019 by thomwolf
Parent b860e47c

    add head masking and pruning to openai GPT
Showing 2 changed files with 149 additions and 21 deletions:

    pytorch_pretrained_bert/modeling_openai.py  (+79, -21)
    tests/modeling_openai_test.py               (+70, -0)
pytorch_pretrained_bert/modeling_openai.py
@@ -36,6 +36,7 @@ from torch.nn.parameter import Parameter
 from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
 from .modeling import BertLayerNorm as LayerNorm
+from .modeling_gpt2 import prune_conv1d_layer

 logger = logging.getLogger(__name__)
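For context, `prune_conv1d_layer` is defined in `modeling_gpt2.py` and shrinks a `Conv1D` module along one dimension. `Conv1D` in this codebase acts like a linear layer with a transposed weight; a minimal sketch of the shape convention it follows (an illustration inferred from its usage here, not the library code itself):

```python
import torch

# Sketch of the Conv1D(nf, rf=1, nx) convention assumed below:
# it maps [..., nx] -> [..., nf] as x @ weight + bias, weight of shape [nx, nf].
x = torch.randn(2, 5, 768)          # [bsz, seq, nx]
weight = torch.randn(768, 3 * 768)  # [nx, nf]; c_attn uses nf = 3 * n_state (fused q, k, v)
bias = torch.zeros(3 * 768)
out = x.view(-1, 768) @ weight + bias
out = out.view(2, 5, 3 * 768)       # later split into query, key, value along dim 2
```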
@@ -256,7 +257,7 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -265,13 +266,31 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
         self.output_attentions = output_attentions
+        self.keep_multihead_output = keep_multihead_output
+        self.multihead_output = None
+
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)

-    def _attn(self, q, k, v):
+    def prune_heads(self, heads):
+        mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+        # Prune conv1d layers
+        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+        # Update hyper params
+        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+        self.n_head = self.n_head - len(heads)
+
+    def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
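Because `c_attn` projects to the fused `[q; k; v]` of width `3 * split_size`, removing a head means deleting its columns from each of the three blocks, which is what the `index_attn` offsets accomplish. A self-contained illustration of the arithmetic, with toy sizes rather than the model's real configuration:

```python
import torch

n_head, head_size = 4, 2
split_size = n_head * head_size          # 8
mask = torch.ones(n_head, head_size)
mask[1] = 0                              # prune head 1
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
print(index)                             # tensor([0, 1, 4, 5, 6, 7]) -- surviving columns
index_attn = torch.cat([index, index + split_size, index + 2 * split_size])
# index_attn repeats the same selection inside the q, k and v blocks of c_attn's output.
```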
@@ -282,6 +301,11 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)

+        # Mask heads if we want to
+        if head_mask is not None:
+            w = w * head_mask
+
         if self.output_attentions:
             return w, torch.matmul(w, v)
         return torch.matmul(w, v)
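Note where the mask lands: after the softmax and the attention dropout. A masked head therefore contributes an exactly-zero context vector, and the surviving heads' attention weights are not renormalized. A quick demonstration of the multiplication (shapes assumed from the comment in the model's `forward` below):

```python
import torch

w = torch.softmax(torch.randn(1, 2, 3, 3), dim=-1)      # [bsz, n_head, seq, seq]
v = torch.randn(1, 2, 3, 4)                              # [bsz, n_head, seq, head_dim]
multiplier = torch.tensor([0.0, 1.0]).view(1, 2, 1, 1)   # head 0 masked (after the 1.0 - x flip)
ctx = torch.matmul(w * multiplier, v)
assert torch.all(ctx[:, 0] == 0)                         # masked head outputs zeros
assert torch.allclose(ctx[:, 1], torch.matmul(w, v)[:, 1])  # kept head unchanged
```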
@@ -299,13 +323,18 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)

-    def forward(self, x):
+    def forward(self, x, head_mask=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
-        a = self._attn(query, key, value)
+        a = self._attn(query, key, value, head_mask)
+        if self.keep_multihead_output:
+            self.multihead_output = a
+            self.multihead_output.retain_grad()
+
         if self.output_attentions:
             attentions, a = a
         a = self.merge_heads(a)
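`retain_grad()` is what lets the tests read gradients off these intermediate, non-leaf tensors after `backward()`. One caveat worth flagging: when `output_attentions=True`, `a` is still the `(attentions, output)` tuple at this point, so `keep_multihead_output` appears to assume the default `output_attentions=False` path, which is the only one the new tests exercise. A minimal sketch of the pattern itself:

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
h = x * 2          # non-leaf intermediate, analogous to the per-head output
h.retain_grad()    # without this, h.grad would stay None after backward()
h.sum().backward()
print(h.grad)      # all-ones tensor, now inspectable
```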
@@ -332,17 +361,17 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Block, self).__init__()
         nx = config.n_embd
         self.output_attentions = output_attentions
-        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)

-    def forward(self, x):
-        a = self.attn(x)
+    def forward(self, x, head_mask=None):
+        a = self.attn(x, head_mask=head_mask)
         if self.output_attentions:
             attentions, a = a
         n = self.ln_1(x + a)
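The unchanged residual wiring matters here: even if the attention branch's per-head contributions are zeroed out, `ln_1(x + a)` still carries `x` forward, so masking heads degrades the layer rather than severing it. A toy check of that property (the zero tensor stands in for a fully suppressed attention contribution, an idealization since `c_proj`'s bias still remains in practice):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(8)
x = torch.randn(2, 4, 8)
a = torch.zeros_like(x)              # stand-in for a suppressed attention branch
assert torch.allclose(ln(x + a), ln(x))
```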
@@ -614,13 +643,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTModel, self).__init__(config)
         self.output_attentions = output_attentions
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
+                      keep_multihead_output=keep_multihead_output)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.apply(self.init_weights)
@@ -639,7 +669,20 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         # Copy word embeddings from the previous weights
         self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.h[layer].attn.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [h.attn.multihead_output for h in self.h]
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
         if position_ids is None:
             # This was used when we had a single embedding matrix for position and token embeddings
             # start = self.config.vocab_size + self.config.n_special
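A hedged usage sketch of the two new public methods (not part of the diff; it assumes the usual `from_pretrained` loading path, which forwards extra keyword arguments to the constructor):

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTModel

model = OpenAIGPTModel.from_pretrained('openai-gpt', keep_multihead_output=True)
model.prune_heads({0: [1, 2], 3: [0]})    # drop heads 1 and 2 of layer 0, head 0 of layer 3
hidden_states = model(torch.tensor([[0, 1, 2]]))
per_head = model.get_multihead_outputs()  # one tensor per layer: [bsz, n_head, seq, head_dim]
```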
@@ -648,6 +691,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
             position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each instance in the batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_ids.size(-1))
         position_ids = position_ids.view(-1, position_ids.size(-1))
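A shape walk-through of the mask preparation above (assuming `n_head = 12`): the convention is that `1.0` marks a head to mask, and the final `1.0 - head_mask` turns that into a multiplicative keep/drop factor.

```python
import torch

head_mask = torch.zeros(12)
head_mask[0] = 1.0                                    # mask head 0
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
print(head_mask.shape)                                # torch.Size([1, 12, 1, 1])
multiplier = 1.0 - head_mask                          # 0.0 for masked heads, 1.0 for kept ones
# broadcasts against attention weights of shape [bsz, n_head, seq, seq]
```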
@@ -664,11 +718,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         all_attentions = []
         for block in self.h:
+            outputs = block(hidden_states, head_mask)
             if self.output_attentions:
-                attentions, hidden_states = block(hidden_states)
+                attentions, hidden_states = outputs
                 all_attentions.append(attentions)
             else:
-                hidden_states = block(hidden_states)
+                hidden_states = outputs
         output_shape = input_shape + (hidden_states.size(-1),)
         if self.output_attentions:
             return all_attentions, hidden_states.view(*output_shape)
@@ -731,9 +786,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTLMHeadModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
+                                          keep_multihead_output=keep_multihead_output)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.apply(self.init_weights)
@@ -745,8 +801,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
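A hedged sketch of passing a head mask through the LM head model (the token ids are hypothetical; remember that `1.0` means mask):

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel

model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()
head_mask = torch.zeros(model.config.n_head)
head_mask[3] = 1.0                       # silence head 3 in every layer
input_ids = torch.tensor([[0, 1, 2]])    # hypothetical token ids
lm_logits = model(input_ids, head_mask=head_mask)
```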
@@ -825,9 +881,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
+                                          keep_multihead_output=keep_multihead_output)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -840,8 +897,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+                position_ids=None, head_mask=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
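Taken together, `head_mask` and `keep_multihead_output` enable gradient-based head-importance analysis, mirroring the GPT-2 implementation this commit borrows `prune_conv1d_layer` from. A hedged end-to-end sketch; the importance proxy is an assumption of this example, not something the commit defines:

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel

model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt', keep_multihead_output=True)
model.eval()
input_ids = torch.tensor([[0, 1, 2, 3]])
loss = model(input_ids, lm_labels=input_ids)   # LM loss against the inputs themselves
loss.backward()
for layer, out in enumerate(model.transformer.get_multihead_outputs()):
    importance = (out * out.grad).abs().sum(dim=(0, 2, 3))  # one score per head
    print(layer, importance)
```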
tests/modeling_openai_test.py
@@ -182,6 +182,73 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 [list(l.size()) for l in result["loss"]],
                 [[], []])

+        def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                    mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.n_head).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0  # Mask all but the first and last heads
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids, head_mask=head_mask)
+                else:
+                    output = model(input_ids, head_mask=head_mask)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.n_head - 1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.n_head - 1, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+        def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                     mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
+                heads_to_prune = {0: list(range(1, self.n_head)),
+                                  -1: [0]}
+                transformer.prune_heads(heads_to_prune)
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids)
+                else:
+                    output = model(input_ids)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = transformer.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, 1,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head - 1,
+                     self.seq_length, self.n_embd // self.n_head])
+
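One detail that makes `heads_to_prune = {0: ..., -1: [0]}` work: `nn.ModuleList` accepts negative indices just like a Python list, so `-1` addresses the last transformer block. A quick check:

```python
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(2, 2) for _ in range(3)])
assert layers[-1] is layers[2]   # negative indexing reaches the last block
```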
     def test_default(self):
         self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
@@ -220,6 +287,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
         tester.check_openai_double_heads_output(output_result)
         tester.check_openai_double_heads_loss_output(output_result)

+        tester.create_and_check_openai_for_headmasking(*config_and_inputs)
+        tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""