Commit fb8f4277 authored May 27, 2020 by Victor SANH

add scripts

parent d489a6d3
Showing 2 changed files with 224 additions and 0 deletions

examples/movement-pruning/bertarize.py          +132 -0
examples/movement-pruning/counts_parameters.py   +92 -0
examples/movement-pruning/bertarize.py  0 → 100644
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once for all.
For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
as a standard :class:`~transformers.BertForSequenceClassification`.
"""
import os
import shutil
import argparse

import torch

from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer


def main(args):
    pruning_method = args.pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    for name, tensor in model.items():
        # Embeddings, LayerNorm, pooler, task heads and biases are not masked during fine-pruning: copy them unchanged.
        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
            pruned_model[name] = tensor
            print(f"Pruned layer {name}")
        elif "classifier" in name or "qa_output" in name:
            pruned_model[name] = tensor
            print(f"Pruned layer {name}")
        elif "bias" in name:
            pruned_model[name] = tensor
            print(f"Pruned layer {name}")
        else:
            # Encoder weight matrices: recover the binary mask (from the weights themselves for
            # `magnitude`, from the stored `mask_scores` otherwise) and apply it to the dense tensor.
            if pruning_method == "magnitude":
                mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "topK":
                if "mask_scores" in name:
                    continue
                # `name` ends with "weight"; the matching scores are stored under "<prefix>mask_scores".
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = TopKBinarizer.apply(scores, threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "sigmoied_threshold":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = ThresholdBinarizer.apply(scores, threshold, True)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "l0":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                # Hard-concrete stretch interval
                l, r = -0.1, 1.1
                s = torch.sigmoid(scores)
                s_bar = s * (r - l) + l
                mask = s_bar.clamp(min=0.0, max=1.0)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            else:
                raise ValueError("Unknown pruning method")

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
        )

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pruning_method",
        choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
        "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
        "Not needed for `l0`",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    parser.add_argument(
        "--target_model_path",
        default=None,
        type=str,
        required=False,
        help="Folder where the pruned model will be saved. Defaults to `bertarized_<model_name_or_path>` next to the original folder",
    )
    args = parser.parse_args()

    main(args)
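For reference, a minimal usage sketch of the conversion above. The checkpoint folder name, the pruning method, and the threshold value are placeholders, not taken from this commit; loading the bertarized folder as a plain BertForSequenceClassification follows the claim in the module docstring.

# Hypothetical invocation (placeholder path and threshold):
#
#   python examples/movement-pruning/bertarize.py \
#       --pruning_method topK \
#       --threshold 0.10 \
#       --model_name_or_path serialization_dir/fine_pruned_squad
#
# The pruned weights are written next to the input folder, e.g.
# serialization_dir/bertarized_fine_pruned_squad/pytorch_model.bin.
# Per the docstring, that folder can then be loaded as a standard model:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("serialization_dir/bertarized_fine_pruned_squad")
model.eval()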
examples/movement-pruning/counts_parameters.py  0 → 100644
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
"""
import os
import argparse

import torch

from emmental.modules import TopKBinarizer, ThresholdBinarizer


def main(args):
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold

    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weights")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if "mask_scores" in name:
            # Binarize the scores to count how many weights the mask keeps.
            if pruning_method == "topK":
                mask_ones = TopKBinarizer.apply(param, threshold).sum().item()
            elif pruning_method == "sigmoied_threshold":
                mask_ones = ThresholdBinarizer.apply(param, threshold, True).sum().item()
            elif pruning_method == "l0":
                # Hard-concrete stretch interval
                l, r = -0.1, 1.1
                s = torch.sigmoid(param)
                s_bar = s * (r - l) + l
                mask = s_bar.clamp(min=0.0, max=1.0)
                mask_ones = (mask > 0.0).sum().item()
            else:
                raise ValueError("Unknown pruning method")
            remaining_count += mask_ones
            print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
        else:
            encoder_count += param.numel()
            # Biases and LayerNorm parameters are never pruned, so they count fully as remaining.
            if "bias" in name or "LayerNorm" in name:
                remaining_count += param.numel()

    print("")
    print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pruning_method",
        choices=["l0", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
        "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
        "Not needed for `l0`",
    )
    parser.add_argument(
        "--serialization_dir",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    args = parser.parse_args()

    main(args)
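A similarly hedged sketch for the counting script: the invocation below uses placeholder folder and threshold values, and the toy tensor only illustrates the sparsity / remaining-weights relation stated in the module docstring, not the script's actual mask-based count.

# Hypothetical invocation (placeholder path and threshold):
#
#   python examples/movement-pruning/counts_parameters.py \
#       --pruning_method topK \
#       --threshold 0.10 \
#       --serialization_dir serialization_dir/fine_pruned_squad
#
# For each `mask_scores` tensor in the encoder, the script prints the layer name, the
# percentage of weights kept by the binarized mask, and the raw count, followed by the
# global remaining-weight percentage over all encoder parameters.

import torch

# Toy tensor: 3 of 6 entries are non-zero, so remaining weights = 50% and sparsity = 50%.
w = torch.tensor([0.0, 0.3, 0.0, -1.2, 0.0, 0.7])
remaining_pct = 100.0 * (w != 0.0).sum().item() / w.numel()
sparsity_pct = 100.0 - remaining_pct
print(remaining_pct, sparsity_pct)  # 50.0 50.0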