chenpangpang / transformers · Commit 4d6dfbd3
Authored Oct 02, 2019 by VictorSanh; committed by Victor SANH, Oct 03, 2019

update extract

Parent: 23edebc0
Showing 1 changed file with 89 additions and 0 deletions.

examples/distillation/scripts/extract.py (new file, 0 → 100644, +89 −0)
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
import torch
import argparse


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation")
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default='roberta-large', type=str)
    parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str)
    parser.add_argument("--vocab_transform", action='store_true')
    args = parser.parse_args()

    if args.model_type == 'roberta':
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = 'roberta'
    elif args.model_type == 'gpt2':
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        prefix = 'transformer'

    state_dict = model.state_dict()
    compressed_sd = {}

    ### Embeddings ###
    if args.model_type == 'gpt2':
        for param_name in ['wte.weight', 'wpe.weight']:
            compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}']
    else:
        for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']:
            param_name = f'{prefix}.embeddings.{w}.weight'
            compressed_sd[param_name] = state_dict[param_name]
        for w in ['weight', 'bias']:
            param_name = f'{prefix}.embeddings.LayerNorm.{w}'
            compressed_sd[param_name] = state_dict[param_name]

    ### Transformer Blocks ###
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        if args.model_type == 'gpt2':
            for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']:
                for w in ['weight', 'bias']:
                    compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \
                        state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}']
            compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias']
        else:
            for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value',
                          'attention.output.dense', 'attention.output.LayerNorm',
                          'intermediate.dense', 'output.dense', 'output.LayerNorm']:
                for w in ['weight', 'bias']:
                    compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \
                        state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}']
        std_idx += 1

    ### Language Modeling Head ###
    if args.model_type == 'roberta':
        for layer in ['lm_head.decoder.weight', 'lm_head.bias']:
            compressed_sd[f'{layer}'] = state_dict[f'{layer}']
        if args.vocab_transform:
            for w in ['weight', 'bias']:
                compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}']
                compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
    elif args.model_type == 'gpt2':
        for w in ['weight', 'bias']:
            compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}']
        compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight']

    print(f'N layers selected for distillation: {std_idx}')
    print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
    print(f'Save transferred checkpoint to {args.dump_checkpoint}.')
    torch.save(compressed_sd, args.dump_checkpoint)
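For context, here is a minimal sketch of how the dumped checkpoint could be loaded into a smaller student model; it is not part of this commit. The checkpoint path, the use of roberta-large dimensions, and the 6-layer override are illustrative assumptions.

# Sketch only -- not part of this commit. Assumes the default --dump_checkpoint
# path above and a hypothetical 6-layer RoBERTa student.
import torch
from transformers import RobertaConfig, RobertaForMaskedLM

compressed_sd = torch.load('serialization_dir/tf_roberta_048131723.pth')
print(len(compressed_sd))  # number of transferred tensors

# Hypothetical student: roberta-large dimensions, but only 6 hidden layers.
student_config = RobertaConfig.from_pretrained('roberta-large', num_hidden_layers=6)
student = RobertaForMaskedLM(student_config)

# strict=False because the student may contain parameters (buffers, tied or
# optional weights) that are not present in the extracted state dict.
load_result = student.load_state_dict(compressed_sd, strict=False)
print(load_result)  # any missing / unexpected keys

The distillation training code in this example directory builds its student from its own configuration files; the snippet above only illustrates that the extracted parameter names line up with a 6-layer RoBERTa.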