Commit a139d1a1 (unverified) to chenpangpang/transformers
Authored Jun 08, 2020 by Sam Shleifer; committed by GitHub, Jun 08, 2020
[cleanup] consolidate some prune_heads logic (#4799)
Parent: 4c7f564f
Showing 8 changed files with 54 additions and 59 deletions (+54 −59)
src/transformers/modeling_albert.py      +4  −9
src/transformers/modeling_bert.py        +4  −9
src/transformers/modeling_distilbert.py  +2  −8
src/transformers/modeling_gpt2.py        +10 −9
src/transformers/modeling_openai.py      +10 −8
src/transformers/modeling_t5.py          +2  −8
src/transformers/modeling_utils.py       +14 −0
src/transformers/modeling_xlm.py         +8  −8
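Every file below rewires the same user-facing path: PreTrainedModel.prune_heads dispatches to each attention module's prune_heads, and the duplicated mask/index bookkeeping there is replaced by one shared helper. A minimal sketch of that entry point, not part of this commit; the tiny BertConfig sizes and the chosen layers/heads are illustrative only.

# Sketch only: exercises model.prune_heads end to end on a toy BERT.
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
model = BertModel(config)

# Prune heads 0 and 2 of layer 0 and head 1 of layer 1. PreTrainedModel.prune_heads
# forwards this dict to each layer's attention.prune_heads, which (after this commit)
# calls find_pruneable_heads_and_indices.
model.prune_heads({0: [0, 2], 1: [1]})

attn = model.encoder.layer[0].attention.self
print(attn.num_attention_heads)  # 2 -- two heads remain in layer 0
print(attn.query.weight.shape)   # torch.Size([32, 64]) -- 2 surviving heads * head size 16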
src/transformers/modeling_albert.py

@@ -26,7 +26,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
-from .modeling_utils import PreTrainedModel
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
 
 
 logger = logging.getLogger(__name__)

@@ -199,14 +199,9 @@ class AlbertAttention(BertSelfAttention):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.num_attention_heads, self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and emove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
+        )
 
         # Prune linear layers
         self.query = prune_linear_layer(self.query, index)
src/transformers/modeling_bert.py

@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .activations import gelu, gelu_new, swish
 from .configuration_bert import BertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 
 
 logger = logging.getLogger(__name__)

@@ -284,14 +284,9 @@ class BertAttention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
 
         # Prune linear layers
         self.self.query = prune_linear_layer(self.self.query, index)
src/transformers/modeling_distilbert.py

@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from .activations import gelu
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 
 
 logger = logging.getLogger(__name__)

@@ -120,13 +120,7 @@ class MultiHeadSelfAttention(nn.Module):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
         # Prune linear layers
         self.q_lin = prune_linear_layer(self.q_lin, index)
         self.k_lin = prune_linear_layer(self.k_lin, index)
src/transformers/modeling_gpt2.py

@@ -27,7 +27,13 @@ from torch.nn import CrossEntropyLoss
 from .activations import ACT2FN
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
+from .modeling_utils import (
+    Conv1D,
+    PreTrainedModel,
+    SequenceSummary,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+)
 
 
 logger = logging.getLogger(__name__)

@@ -122,14 +128,9 @@ class Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads  # Convert to set and emove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+        )
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
 
         # Prune conv1d layers
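In GPT-2 (and in OpenAI GPT below), the query/key/value projections live in a single fused Conv1D, c_attn, whose output columns are laid out as one query block, one key block and one value block, each self.split_size wide. The index returned by the new helper only covers one such block, so the unchanged index_attn line offsets it three times before pruning along dim=1. A standalone sketch with toy sizes, illustrative only and not part of the diff:

import torch

n_head, head_size = 3, 4
split_size = n_head * head_size                 # width of one q/k/v block: 12

# Suppose head 0 is pruned: `index` holds the surviving columns within a single
# block (heads 1 and 2 -> columns 4..11), as find_pruneable_heads_and_indices returns.
index = torch.arange(head_size, split_size)

# Repeat the same selection in the query, key and value blocks of the fused weight.
index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])
print(index_attn.tolist())
# columns 4-11, 16-23 and 28-35: exactly what prune_conv1d_layer(..., dim=1) keeps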
src/transformers/modeling_openai.py

@@ -29,7 +29,13 @@ from torch.nn import CrossEntropyLoss
 from .activations import gelu_new, swish
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
+from .modeling_utils import (
+    Conv1D,
+    PreTrainedModel,
+    SequenceSummary,
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+)
 
 
 logger = logging.getLogger(__name__)

@@ -142,13 +148,9 @@ class Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
+        )
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
         # Prune conv1d layers
         self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
src/transformers/modeling_t5.py

@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, prune_linear_layer
+from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 
 
 logger = logging.getLogger(__name__)

@@ -216,13 +216,7 @@ class T5Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, self.d_kv)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
         # Prune linear layers
         self.q = prune_linear_layer(self.q, index)
         self.k = prune_linear_layer(self.k, index)
src/transformers/modeling_utils.py

@@ -55,6 +55,20 @@ except ImportError:
         return input
 
 
+def find_pruneable_heads_and_indices(
+    heads: List, n_heads: int, head_size: int, already_pruned_heads: set
+) -> Tuple[set, "torch.LongTensor"]:
+    mask = torch.ones(n_heads, head_size)
+    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
+    for head in heads:
+        # Compute how many pruned heads are before the head and move the index accordingly
+        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+        mask[head] = 0
+    mask = mask.view(-1).contiguous().eq(1)
+    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+    return heads, index
+
+
 class ModuleUtilsMixin:
     """
     A few utilities for torch.nn.Modules, to be used as a mixin.
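A small usage sketch of the new helper in isolation; the toy values are illustrative, and the import path follows the module this diff adds the function to:

import torch
from transformers.modeling_utils import find_pruneable_heads_and_indices

# 4 heads of size 8, nothing pruned yet; ask to prune heads 0 and 2.
heads, index = find_pruneable_heads_and_indices(
    heads=[0, 2], n_heads=4, head_size=8, already_pruned_heads=set()
)

print(sorted(heads))  # [0, 2] -- the heads that still need pruning
print(index.shape)    # torch.Size([16]) -- flat positions of the 2 surviving heads
print(index[:3])      # tensor([ 8,  9, 10]) -- head 1's slice starts at 1 * head_size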
src/transformers/modeling_xlm.py

@@ -29,7 +29,13 @@ from torch.nn import functional as F
 from .activations import gelu
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
+from .modeling_utils import (
+    PreTrainedModel,
+    SequenceSummary,
+    SQuADHead,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
 
 
 logger = logging.getLogger(__name__)

@@ -105,13 +111,7 @@ class MultiHeadAttention(nn.Module):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
-        heads = set(heads) - self.pruned_heads
-        for head in heads:
-            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
+        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
         # Prune linear layers
         self.q_lin = prune_linear_layer(self.q_lin, index)
         self.k_lin = prune_linear_layer(self.k_lin, index)