chenpangpang / transformers / Commits

Commit 11600edc authored Aug 31, 2019 by LysandreJik

    Rebase on master + DistilBERT head pruning patch

parent b6992b7b
Showing 1 changed file with 10 additions and 5 deletions
pytorch_transformers/modeling_distilbert.py  (+10, -5)
@@ -174,12 +174,16 @@ class MultiHeadSelfAttention(nn.Module):
         self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
         self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
 
+        self.pruned_heads = set()
+
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
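The new bookkeeping keeps head indices stable across repeated pruning calls: each requested head is shifted down by the number of already-pruned heads that precede it before the keep-mask is built. A minimal standalone sketch of that index arithmetic (hypothetical sizes, not part of the commit):

import torch

# Hypothetical layer state: originally 4 heads of size 3, head 1 already pruned,
# so the layer currently has n_heads = 3; the caller now asks to prune original head 2.
n_heads, attention_head_size = 3, 3
pruned_heads = {1}
heads = {2} - pruned_heads                 # already-pruned heads are ignored

mask = torch.ones(n_heads, attention_head_size)
for head in heads:
    # shift the original index by the number of pruned heads that come before it
    head -= sum(1 if h < head else 0 for h in pruned_heads)
    mask[head] = 0                         # original head 2 is row 1 of the current mask

mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
print(index)                               # tensor([0, 1, 2, 6, 7, 8]): dims kept in q/k/v/out_lin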
@@ -191,6 +195,7 @@ class MultiHeadSelfAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
 
     def forward(self, query, key, value, mask, head_mask=None):
         """
@@ -395,7 +400,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
 
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, nn.Embedding):
@@ -480,7 +485,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings = Embeddings(config)   # Embeddings
         self.transformer = Transformer(config) # Encoder
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def _resize_token_embeddings(self, new_num_tokens):
         old_embeddings = self.embeddings.word_embeddings
@@ -568,7 +573,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
 
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()
 
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
@@ -642,7 +647,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         self.classifier = nn.Linear(config.dim, config.num_labels)
         self.dropout = nn.Dropout(config.seq_classif_dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
@@ -716,7 +721,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         assert config.num_labels == 2
         self.dropout = nn.Dropout(config.qa_dropout)
 
-        self.apply(self.init_weights)
+        self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
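The remaining hunks all follow from renaming the per-module initializer: DistilBertPreTrainedModel now exposes _init_weights(module), and each subclass constructor calls init_weights() instead of self.apply(self.init_weights). A minimal sketch of that split, assuming (not the library's exact code) that the shared PreTrainedModel base class walks the module tree and delegates to _init_weights:

import torch.nn as nn

class PreTrainedModelSketch(nn.Module):
    # Assumed shape of the base-class pattern, not the real implementation.
    def _init_weights(self, module):
        raise NotImplementedError          # each model class defines its per-module init

    def init_weights(self):
        self.apply(self._init_weights)     # apply the per-module initializer everywhere
        # the real base class may also tie embeddings or apply configured head pruning here

A subclass __init__ then ends with self.init_weights(), which is exactly what the DistilBert* hunks above switch to.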