Commit 23edebc0 in chenpangpang/transformers

update extract_distilbert

Authored Oct 02, 2019 by VictorSanh; committed by Victor SANH on Oct 03, 2019
Parent: cbfcfce2
Showing 1 changed file with 10 additions and 18 deletions.
examples/distillation/scripts/extract_for_distil.py → examples/distillation/scripts/extract_distilbert.py (+10, -18)
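For context, a sketch of how the renamed script would be invoked after this change. The flag names and the default checkpoint path come from the diff below; wrapping the call in subprocess is just for illustration and assumes it runs from the repository root:

    import subprocess

    # Flags as defined by the argparse block in the diff below; --model_type
    # now only accepts "bert". The dump path is the script's own default.
    subprocess.run([
        'python', 'examples/distillation/scripts/extract_distilbert.py',
        '--model_type', 'bert',
        '--model_name', 'bert-base-uncased',
        '--dump_checkpoint', 'serialization_dir/tf_bert-base-uncased_0247911.pth',
        '--vocab_transform',
    ], check=True)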
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 Preprocessing script before training DistilBERT.
+Specific to BERT -> DistilBERT.
 """
 from transformers import BertForMaskedLM, RobertaForMaskedLM
 import torch
@@ -21,7 +22,7 @@ import argparse
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
-    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
+    parser.add_argument("--model_type", default="bert", choices=["bert"])
     parser.add_argument("--model_name", default='bert-base-uncased', type=str)
     parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
     parser.add_argument("--vocab_transform", action='store_true')
@@ -31,9 +32,8 @@ if __name__ == '__main__':
     if args.model_type == 'bert':
         model = BertForMaskedLM.from_pretrained(args.model_name)
         prefix = 'bert'
-    elif args.model_type == 'roberta':
-        model = RobertaForMaskedLM.from_pretrained(args.model_name)
-        prefix = 'roberta'
+    else:
+        raise ValueError(f'args.model_type should be "bert".')

     state_dict = model.state_dict()
     compressed_sd = {}
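The else/raise above is where the RoBERTa branch used to live. The reason the two teachers ever needed separate branches is visible in the MLM-head key names of their state dicts; a small inspection snippet (not part of the commit; the checkpoint names are the standard Hugging Face ones) shows the difference that the deleted lines in the next hunk were handling:

    from transformers import BertForMaskedLM, RobertaForMaskedLM

    bert_sd = BertForMaskedLM.from_pretrained('bert-base-uncased').state_dict()
    roberta_sd = RobertaForMaskedLM.from_pretrained('roberta-base').state_dict()

    # BERT's MLM head lives under cls.predictions.*, RoBERTa's under lm_head.*
    print(sorted(k for k in bert_sd if k.startswith('cls.predictions')))
    print(sorted(k for k in roberta_sd if k.startswith('lm_head')))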
@@ -68,20 +68,12 @@ if __name__ == '__main__':
                 state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1

-    if args.model_type == 'bert':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
-    elif args.model_type == 'roberta':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
+    compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
+    compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
+    if args.vocab_transform:
+        for w in ['weight', 'bias']:
+            compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
+            compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']

     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
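The loop that actually selects teacher layers sits in the context GitLab collapses between the hunks; only its tail (the {prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w} copy and std_idx += 1) is visible above. A self-contained sketch of the overall pattern, where the layer selection [0, 2, 4, 7, 9, 11] and the student-side key names are assumptions (they follow DistilBERT's usual 6-layer setup and module naming, not anything shown in this diff):

    import torch
    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    state_dict = model.state_dict()
    prefix = 'bert'
    compressed_sd = {}

    # Copy a subset of the 12 teacher layers into consecutive student slots.
    # The index list is an assumption; this part of the script is collapsed above.
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ['weight', 'bias']:
            # Student-side key names assumed to follow DistilBERT's module naming;
            # the teacher-side key matches the context line visible in the last hunk.
            compressed_sd[f'transformer.layer.{std_idx}.output_layer_norm.{w}'] = \
                state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
        std_idx += 1

    # MLM head, exactly as in the added lines of the last hunk.
    compressed_sd['vocab_projector.weight'] = state_dict['cls.predictions.decoder.weight']
    compressed_sd['vocab_projector.bias'] = state_dict['cls.predictions.bias']

    # This is what --dump_checkpoint does with the result.
    torch.save(compressed_sd, 'tf_bert-base-uncased_0247911.pth')

Loading the dump into a fresh DistilBertForMaskedLM with strict=False would be the assumed next step before running the distillation trainer.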