chenpangpang / transformers / Commits

Commit c51e533a
authored Oct 02, 2019 by VictorSanh, committed by Victor SANH on Oct 03, 2019

update train.py

parent a76c3f9c

Showing 1 changed file with 118 additions and 75 deletions:

examples/distillation/train.py  (+118, -75)
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Training DistilBERT.
+Training the distilled model.
+Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
 """
 import os
 import argparse
@@ -23,68 +24,96 @@ import shutil
 import numpy as np
 import torch

-from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
-from transformers import DistilBertForMaskedLM, DistilBertConfig
+from transformers import BertConfig, BertForMaskedLM, BertTokenizer
+from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

 from distiller import Distiller
 from utils import git_log, logger, init_gpu_params, set_seed
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+
+
+MODEL_CLASSES = {
+    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+}
+
+
+def sanity_checks(args):
+    """
+    A bunch of args sanity checks to perform even starting...
+    """
+    assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
+    assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
+    if args.mlm:
+        assert os.path.isfile(args.token_counts)
+        assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
+    else:
+        assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
+
+    assert args.teacher_type == args.student_type or (args.student_type == 'distilbert' and args.teacher_type == 'bert')
+    assert os.path.isfile(args.student_config)
+    if args.student_pretrained_weights is not None:
+        assert os.path.isfile(args.student_pretrained_weights)
+
+    if args.freeze_token_type_embds:
+        assert args.student_type in ['roberta']
+
+    assert args.alpha_ce >= 0.
+    assert args.alpha_mlm >= 0.
+    assert args.alpha_clm >= 0.
+    assert args.alpha_mse >= 0.
+    assert args.alpha_cos >= 0.
+    assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
+
+
+def freeze_pos_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
+    elif args.student_type == 'gpt2':
+        student.transformer.wpe.weight.requires_grad = False
+
+
+def freeze_token_type_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False


 def main():
     parser = argparse.ArgumentParser(description="Training")
+    parser.add_argument("--force", action='store_true',
+                        help="Overwrite dump_path if it already exists.")
+
     parser.add_argument("--dump_path", type=str, required=True,
                         help="The output directory (log, checkpoints, parameters, etc.)")
     parser.add_argument("--data_file", type=str, required=True,
                         help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
-    parser.add_argument("--token_counts", type=str, required=True,
-                        help="The token counts in the data_file for MLM.")
-    parser.add_argument("--force", action='store_true',
-                        help="Overwrite dump_path if it already exists.")
-
-    parser.add_argument("--vocab_size", default=30522, type=int,
-                        help="The vocabulary size.")
-    parser.add_argument("--max_position_embeddings", default=512, type=int,
-                        help="Maximum sequence length we can model (including [CLS] and [SEP]).")
-    parser.add_argument("--sinusoidal_pos_embds", action='store_false',
-                        help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
-    parser.add_argument("--n_layers", default=6, type=int,
-                        help="Number of Transformer blocks.")
-    parser.add_argument("--n_heads", default=12, type=int,
-                        help="Number of heads in the self-attention module.")
-    parser.add_argument("--dim", default=768, type=int,
-                        help="Dimension through the network. Must be divisible by n_heads")
-    parser.add_argument("--hidden_dim", default=3072, type=int,
-                        help="Intermediate dimension in the FFN.")
-    parser.add_argument("--dropout", default=0.1, type=float,
-                        help="Dropout.")
-    parser.add_argument("--attention_dropout", default=0.1, type=float,
-                        help="Dropout in self-attention.")
-    parser.add_argument("--activation", default='gelu', type=str,
-                        help="Activation to use in self-attention")
-    parser.add_argument("--tie_weights_", action='store_false',
-                        help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")
-
-    parser.add_argument("--from_pretrained_weights", default=None, type=str,
-                        help="Load student initialization checkpoint.")
-    parser.add_argument("--from_pretrained_config", default=None, type=str,
-                        help="Load student initialization architecture config.")
-    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
+    parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
+                        help="The student type (DistilBERT, RoBERTa).")
+    parser.add_argument("--student_config", type=str, required=True,
+                        help="Path to the student configuration.")
+    parser.add_argument("--student_pretrained_weights", default=None, type=str,
+                        help="Load student initialization checkpoint.")
+
+    parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
                         help="Teacher type (BERT, RoBERTa).")
-    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
+    parser.add_argument("--teacher_name", type=str, required=True,
                         help="The teacher model.")
     parser.add_argument("--temperature", default=2., type=float,
                         help="Temperature for the softmax temperature.")
     parser.add_argument("--alpha_ce", default=0.5, type=float,
                         help="Linear weight for the distillation loss. Must be >=0.")
-    parser.add_argument("--alpha_mlm", default=0.5, type=float,
-                        help="Linear weight for the MLM loss. Must be >=0.")
+    parser.add_argument("--alpha_mlm", default=0.0, type=float,
+                        help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
+    parser.add_argument("--alpha_clm", default=0.5, type=float,
+                        help="Linear weight for the CLM loss. Must be >=0.")
     parser.add_argument("--alpha_mse", default=0.0, type=float,
                         help="Linear weight of the MSE loss. Must be >=0.")
     parser.add_argument("--alpha_cos", default=0.0, type=float,
                         help="Linear weight of the cosine embedding loss. Must be >=0.")
+
+    parser.add_argument("--mlm", action="store_true",
+                        help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
     parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
                         help="Proportion of tokens for which we need to make a prediction.")
     parser.add_argument("--word_mask", default=0.8, type=float,
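Note: the hard-coded BERT/RoBERTa branches of the old script are replaced by the MODEL_CLASSES registry added above. A minimal sketch (not part of the commit; the checkpoint name below is a placeholder) of how that registry is consumed later in main():

from transformers import BertConfig, BertForMaskedLM, BertTokenizer
from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer

MODEL_CLASSES = {
    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
}

teacher_type = 'bert'               # value of --teacher_type
teacher_name = 'bert-base-uncased'  # placeholder for --teacher_name

# One lookup yields the config/model/tokenizer classes, so no per-architecture branching is needed.
_, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[teacher_type]
tokenizer = teacher_tokenizer_class.from_pretrained(teacher_name)
teacher = teacher_model_class.from_pretrained(teacher_name, output_hidden_states=True)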
@@ -95,17 +124,20 @@ def main():
                         help="Proportion of tokens to randomly replace.")
     parser.add_argument("--mlm_smoothing", default=0.7, type=float,
                         help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
+    parser.add_argument("--token_counts", type=str,
+                        help="The token counts in the data_file for MLM.")
     parser.add_argument("--restrict_ce_to_mask", action='store_true',
                         help="If true, compute the distilation loss only the [MLM] prediction distribution.")
+    parser.add_argument("--freeze_pos_embs", action="store_true",
+                        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
+    parser.add_argument("--freeze_token_type_embds", action="store_true",
+                        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")

     parser.add_argument("--n_epoch", type=int, default=3,
                         help="Number of pass on the whole dataset.")
     parser.add_argument("--batch_size", type=int, default=5,
                         help="Batch size (for each process).")
-    parser.add_argument("--tokens_per_batch", type=int, default=-1,
-                        help="If specified, modify the batches so that they have approximately this number of tokens.")
-    parser.add_argument("--shuffle", action='store_false',
-                        help="If true, shuffle the sequence order. Default is true.")
     parser.add_argument("--group_by_size", action='store_false',
                         help="If true, group sequences that have similar length into the same batch. Default is true.")
@@ -141,6 +173,7 @@ def main():
     parser.add_argument("--checkpoint_interval", type=int, default=4000,
                         help="Checkpoint interval.")
     args = parser.parse_args()
+    sanity_checks(args)

     ## ARGS ##
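For reference, a toy illustration (not part of the commit; the alpha values are made up) of the loss-weight constraints that the new sanity_checks(args) call enforces right after argument parsing: exactly one of the MLM/CLM objectives may carry weight, and the alpha_* weights must sum to a positive value.

from argparse import Namespace

mlm_args = Namespace(mlm=True, alpha_ce=5.0, alpha_mlm=2.0, alpha_clm=0.0, alpha_mse=0.0, alpha_cos=1.0)
clm_args = Namespace(mlm=False, alpha_ce=5.0, alpha_mlm=0.0, alpha_clm=2.0, alpha_mse=0.0, alpha_cos=0.0)

for a in (mlm_args, clm_args):
    # mirrors the first two asserts and the final sum check of sanity_checks()
    assert (a.mlm and a.alpha_mlm > 0.) or (not a.mlm and a.alpha_mlm == 0.)
    assert (a.alpha_mlm > 0. and a.alpha_clm == 0.) or (a.alpha_mlm == 0. and a.alpha_clm > 0.)
    assert a.alpha_ce + a.alpha_mlm + a.alpha_clm + a.alpha_mse + a.alpha_cos > 0.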
@@ -164,21 +197,19 @@ def main():
     with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
         json.dump(vars(args), f, indent=4)
     git_log(args.dump_path)
-    assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
-           (args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
+
+    student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
+    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]

     ### TOKENIZER ###
-    if args.teacher_type == 'bert':
-        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
-    elif args.teacher_type == 'roberta':
-        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
+    tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
         idx = tokenizer.all_special_tokens.index(tok_symbol)
         special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
     logger.info(f'Special tokens {special_tok_ids}')
     args.special_tok_ids = special_tok_ids
+    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

     ## DATA LOADER ##
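For context, a hedged mini-example (not in the commit; the tokenizer name is a placeholder) of the special_tok_ids mapping built above: each entry of special_tokens_map ('mask_token', 'pad_token', ...) is resolved to its vocabulary id so that masking and padding can be done by id downstream.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # placeholder teacher name

special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
    idx = tokenizer.all_special_tokens.index(tok_symbol)
    special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]

print(special_tok_ids)  # e.g. {'unk_token': 100, 'sep_token': 102, 'pad_token': 0, 'cls_token': 101, 'mask_token': 103}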
@@ -187,35 +218,34 @@ def main():
         data = pickle.load(fp)

-    assert os.path.isfile(args.token_counts)
-    logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
-    with open(args.token_counts, 'rb') as fp:
-        counts = pickle.load(fp)
-    assert len(counts) == args.vocab_size
-    token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
-    for idx in special_tok_ids.values():
-        token_probs[idx] = 0.  # do not predict special tokens
-    token_probs = torch.from_numpy(token_probs)
+    if args.mlm:
+        logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
+        with open(args.token_counts, 'rb') as fp:
+            counts = pickle.load(fp)
+
+        token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
+        for idx in special_tok_ids.values():
+            token_probs[idx] = 0.  # do not predict special tokens
+        token_probs = torch.from_numpy(token_probs)
+    else:
+        token_probs = None

-    train_dataloader = Dataset(params=args, data=data)
+    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f'Data loader created.')

     ## STUDENT ##
-    if args.from_pretrained_weights is not None:
-        assert os.path.isfile(args.from_pretrained_weights)
-        assert os.path.isfile(args.from_pretrained_config)
-        logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
-        logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
-        stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
-        stu_architecture_config.output_hidden_states = True
-        student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
-                                                        config=stu_architecture_config)
+    logger.info(f'Loading student config from {args.student_config}')
+    stu_architecture_config = student_config_class.from_pretrained(args.student_config)
+    stu_architecture_config.output_hidden_states = True
+
+    if args.student_pretrained_weights is not None:
+        logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
+        student = student_model_class.from_pretrained(args.student_pretrained_weights,
+                                                      config=stu_architecture_config)
     else:
-        args.vocab_size_or_config_json_file = args.vocab_size
-        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
-        student = DistilBertForMaskedLM(stu_architecture_config)
+        student = student_model_class(stu_architecture_config)

     if args.n_gpu > 0:
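A small numeric illustration (not from the commit) of the MLM smoothing computed above: raising token counts to the power -mlm_smoothing compresses the gap between frequent and rare tokens before they are sampled for masking, in the spirit of the word2vec/XLM subsampling mentioned in the --mlm_smoothing help string.

import numpy as np

counts = np.array([100000, 1000, 10, 0])   # toy token frequencies
mlm_smoothing = 0.7                         # default of --mlm_smoothing
token_probs = np.maximum(counts, 1.0) ** -mlm_smoothing
print(token_probs / token_probs.sum())      # rare tokens receive a larger share of the masking budget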
@@ -224,18 +254,31 @@ def main():
     ## TEACHER ##
-    if args.teacher_type == 'bert':
-        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
-    elif args.teacher_type == 'roberta':
-        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
+    teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f'cuda:{args.local_rank}')
     logger.info(f'Teacher loaded from {args.teacher_name}.')

+    ## FREEZING ##
+    if args.freeze_pos_embs:
+        freeze_pos_embeddings(student, args)
+    if args.freeze_token_type_embds:
+        freeze_token_type_embeddings(student, args)
+
+    ## SANITY CHECKS ##
+    assert student.config.vocab_size == teacher.config.vocab_size
+    assert student.config.hidden_size == teacher.config.hidden_size
+    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
+    if args.mlm:
+        assert token_probs.size(0) == stu_architecture_config.vocab_size
+
     ## DISTILLER ##
     torch.cuda.empty_cache()
     distiller = Distiller(params=args,
-                          dataloader=train_dataloader,
+                          dataset=train_lm_seq_dataset,
                           token_probs=token_probs,
                           student=student,
                           teacher=teacher)
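Finally, a short sketch (assuming only torch; the module below is a stand-in for the student, not the commit's code) of what the new ## FREEZING ## block does: freeze_pos_embeddings / freeze_token_type_embeddings simply flip requires_grad off on the relevant embedding weight, so it is skipped by any optimizer built over the trainable parameters.

import torch.nn as nn

position_embeddings = nn.Embedding(512, 768)       # stand-in for student.roberta.embeddings.position_embeddings
position_embeddings.weight.requires_grad = False    # what freeze_pos_embeddings() does for RoBERTa students

trainable = [p for p in position_embeddings.parameters() if p.requires_grad]
assert len(trainable) == 0                           # the frozen weight never reaches the optimizer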