chenpangpang / transformers

Commit c51e533a: "update train.py"
Authored Oct 02, 2019 by VictorSanh; committed by Victor SANH on Oct 03, 2019
Parent commit: a76c3f9c

Showing 1 changed file with 118 additions and 75 deletions

examples/distillation/train.py (+118, -75)
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Training DistilBERT.
+Training the distilled model.
+Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
 """
 import os
 import argparse
@@ -23,68 +24,96 @@ import shutil
 import numpy as np
 import torch

-from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
-from transformers import DistilBertForMaskedLM, DistilBertConfig
+from transformers import BertConfig, BertForMaskedLM, BertTokenizer
+from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

 from distiller import Distiller
 from utils import git_log, logger, init_gpu_params, set_seed
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+
+
+MODEL_CLASSES = {
+    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+}
+
+def sanity_checks(args):
+    """
+    A bunch of args sanity checks to perform even starting...
+    """
+    assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
+    assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
+    if args.mlm:
+        assert os.path.isfile(args.token_counts)
+        assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
+    else:
+        assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
+
+    assert args.teacher_type == args.student_type or (args.student_type == 'distilbert' and args.teacher_type == 'bert')
+    assert os.path.isfile(args.student_config)
+    if args.student_pretrained_weights is not None:
+        assert os.path.isfile(args.student_pretrained_weights)
+
+    if args.freeze_token_type_embds:
+        assert args.student_type in ['roberta']
+
+    assert args.alpha_ce >= 0.
+    assert args.alpha_mlm >= 0.
+    assert args.alpha_clm >= 0.
+    assert args.alpha_mse >= 0.
+    assert args.alpha_cos >= 0.
+    assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
+
+def freeze_pos_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
+    elif args.student_type == 'gpt2':
+        student.transformer.wpe.weight.requires_grad = False
+
+def freeze_token_type_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False

 def main():
     parser = argparse.ArgumentParser(description="Training")
+    parser.add_argument("--force", action='store_true',
+                        help="Overwrite dump_path if it already exists.")
+
     parser.add_argument("--dump_path", type=str, required=True,
                         help="The output directory (log, checkpoints, parameters, etc.)")
     parser.add_argument("--data_file", type=str, required=True,
                         help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
-    parser.add_argument("--token_counts", type=str, required=True,
-                        help="The token counts in the data_file for MLM.")
-    parser.add_argument("--force", action='store_true',
-                        help="Overwrite dump_path if it already exists.")

-    parser.add_argument("--vocab_size", default=30522, type=int,
-                        help="The vocabulary size.")
-    parser.add_argument("--max_position_embeddings", default=512, type=int,
-                        help="Maximum sequence length we can model (including [CLS] and [SEP]).")
-    parser.add_argument("--sinusoidal_pos_embds", action='store_false',
-                        help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
-    parser.add_argument("--n_layers", default=6, type=int,
-                        help="Number of Transformer blocks.")
-    parser.add_argument("--n_heads", default=12, type=int,
-                        help="Number of heads in the self-attention module.")
-    parser.add_argument("--dim", default=768, type=int,
-                        help="Dimension through the network. Must be divisible by n_heads")
-    parser.add_argument("--hidden_dim", default=3072, type=int,
-                        help="Intermediate dimension in the FFN.")
-    parser.add_argument("--dropout", default=0.1, type=float,
-                        help="Dropout.")
-    parser.add_argument("--attention_dropout", default=0.1, type=float,
-                        help="Dropout in self-attention.")
-    parser.add_argument("--activation", default='gelu', type=str,
-                        help="Activation to use in self-attention")
-    parser.add_argument("--tie_weights_", action='store_false',
-                        help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")
-
-    parser.add_argument("--from_pretrained_weights", default=None, type=str,
-                        help="Load student initialization checkpoint.")
-    parser.add_argument("--from_pretrained_config", default=None, type=str,
-                        help="Load student initialization architecture config.")
-    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
+    parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
+                        help="The student type (DistilBERT, RoBERTa).")
+    parser.add_argument("--student_config", type=str, required=True,
+                        help="Path to the student configuration.")
+    parser.add_argument("--student_pretrained_weights", default=None, type=str,
+                        help="Load student initialization checkpoint.")
+
+    parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
                         help="Teacher type (BERT, RoBERTa).")
-    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
+    parser.add_argument("--teacher_name", type=str, required=True,
                         help="The teacher model.")

     parser.add_argument("--temperature", default=2., type=float,
                         help="Temperature for the softmax temperature.")
     parser.add_argument("--alpha_ce", default=0.5, type=float,
                         help="Linear weight for the distillation loss. Must be >=0.")
-    parser.add_argument("--alpha_mlm", default=0.5, type=float,
-                        help="Linear weight for the MLM loss. Must be >=0.")
+    parser.add_argument("--alpha_mlm", default=0.0, type=float,
+                        help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
+    parser.add_argument("--alpha_clm", default=0.5, type=float,
+                        help="Linear weight for the CLM loss. Must be >=0.")
     parser.add_argument("--alpha_mse", default=0.0, type=float,
                         help="Linear weight of the MSE loss. Must be >=0.")
     parser.add_argument("--alpha_cos", default=0.0, type=float,
                         help="Linear weight of the cosine embedding loss. Must be >=0.")

+    parser.add_argument("--mlm", action="store_true",
+                        help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
     parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
                         help="Proportion of tokens for which we need to make a prediction.")
     parser.add_argument("--word_mask", default=0.8, type=float,
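For orientation, here is a short sketch (not part of the commit) of how the MODEL_CLASSES dispatch table introduced in this hunk resolves the (config, model, tokenizer) classes selected by the new --student_type / --teacher_type flags. The DistilBERT/BERT pairing below is only an illustrative choice.

# Sketch (not part of this commit): resolving classes from MODEL_CLASSES.
from transformers import (BertConfig, BertForMaskedLM, BertTokenizer,
                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer,
                          GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)

MODEL_CLASSES = {
    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
}

student_type, teacher_type = 'distilbert', 'bert'  # hypothetical BERT -> DistilBERT run

# Each entry is a (config class, model class, tokenizer class) triple; main() below
# unpacks the student's config/model classes and the teacher's model/tokenizer classes.
student_config_class, student_model_class, _ = MODEL_CLASSES[student_type]
_, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[teacher_type]

print(student_config_class.__name__, student_model_class.__name__)     # DistilBertConfig DistilBertForMaskedLM
print(teacher_model_class.__name__, teacher_tokenizer_class.__name__)  # BertForMaskedLM BertTokenizer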
@@ -95,17 +124,20 @@ def main():
                         help="Proportion of tokens to randomly replace.")
     parser.add_argument("--mlm_smoothing", default=0.7, type=float,
                         help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
+    parser.add_argument("--token_counts", type=str,
+                        help="The token counts in the data_file for MLM.")

     parser.add_argument("--restrict_ce_to_mask", action='store_true',
                         help="If true, compute the distilation loss only the [MLM] prediction distribution.")
+    parser.add_argument("--freeze_pos_embs", action="store_true",
+                        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
+    parser.add_argument("--freeze_token_type_embds", action="store_true",
+                        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")

     parser.add_argument("--n_epoch", type=int, default=3,
                         help="Number of pass on the whole dataset.")
     parser.add_argument("--batch_size", type=int, default=5,
                         help="Batch size (for each process).")
-    parser.add_argument("--tokens_per_batch", type=int, default=-1,
-                        help="If specified, modify the batches so that they have approximately this number of tokens.")
-    parser.add_argument("--shuffle", action='store_false',
-                        help="If true, shuffle the sequence order. Default is true.")
     parser.add_argument("--group_by_size", action='store_false',
                         help="If true, group sequences that have similar length into the same batch. Default is true.")
@@ -141,6 +173,7 @@ def main():
     parser.add_argument("--checkpoint_interval", type=int, default=4000,
                         help="Checkpoint interval.")
     args = parser.parse_args()
+    sanity_checks(args)

     ## ARGS ##
@@ -164,21 +197,19 @@ def main():
     with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
         json.dump(vars(args), f, indent=4)
     git_log(args.dump_path)
-    assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
-           (args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
+
+    student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
+    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]

     ### TOKENIZER ###
-    if args.teacher_type == 'bert':
-        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
-    elif args.teacher_type == 'roberta':
-        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
+    tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
         idx = tokenizer.all_special_tokens.index(tok_symbol)
         special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
     logger.info(f'Special tokens {special_tok_ids}')
     args.special_tok_ids = special_tok_ids
+    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

     ## DATA LOADER ##
@@ -187,35 +218,34 @@ def main():
         data = pickle.load(fp)

-    assert os.path.isfile(args.token_counts)
-    logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
-    with open(args.token_counts, 'rb') as fp:
-        counts = pickle.load(fp)
-        assert len(counts) == args.vocab_size
-    token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
-    for idx in special_tok_ids.values():
-        token_probs[idx] = 0.  # do not predict special tokens
-    token_probs = torch.from_numpy(token_probs)
+    if args.mlm:
+        logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
+        with open(args.token_counts, 'rb') as fp:
+            counts = pickle.load(fp)
+
+        token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
+        for idx in special_tok_ids.values():
+            token_probs[idx] = 0.  # do not predict special tokens
+        token_probs = torch.from_numpy(token_probs)
+    else:
+        token_probs = None

-    train_dataloader = Dataset(params=args, data=data)
+    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f'Data loader created.')

     ## STUDENT ##
-    if args.from_pretrained_weights is not None:
-        assert os.path.isfile(args.from_pretrained_weights)
-        assert os.path.isfile(args.from_pretrained_config)
-        logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
-        logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
-        stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
-        stu_architecture_config.output_hidden_states = True
-        student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
-                                                        config=stu_architecture_config)
+    logger.info(f'Loading student config from {args.student_config}')
+    stu_architecture_config = student_config_class.from_pretrained(args.student_config)
+    stu_architecture_config.output_hidden_states = True
+
+    if args.student_pretrained_weights is not None:
+        logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
+        student = student_model_class.from_pretrained(args.student_pretrained_weights,
+                                                      config=stu_architecture_config)
     else:
-        args.vocab_size_or_config_json_file = args.vocab_size
-        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
-        student = DistilBertForMaskedLM(stu_architecture_config)
+        student = student_model_class(stu_architecture_config)

     if args.n_gpu > 0:
@@ -224,18 +254,31 @@ def main():
     ## TEACHER ##
-    if args.teacher_type == 'bert':
-        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
-    elif args.teacher_type == 'roberta':
-        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
+    teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f'cuda:{args.local_rank}')
     logger.info(f'Teacher loaded from {args.teacher_name}.')

+    ## FREEZING ##
+    if args.freeze_pos_embs:
+        freeze_pos_embeddings(student, args)
+    if args.freeze_token_type_embds:
+        freeze_token_type_embeddings(student, args)
+
+    ## SANITY CHECKS ##
+    assert student.config.vocab_size == teacher.config.vocab_size
+    assert student.config.hidden_size == teacher.config.hidden_size
+    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
+    if args.mlm:
+        assert token_probs.size(0) == stu_architecture_config.vocab_size
+
     ## DISTILLER ##
     torch.cuda.empty_cache()
     distiller = Distiller(params=args,
-                          dataloader=train_dataloader,
+                          dataset=train_lm_seq_dataset,
                           token_probs=token_probs,
                           student=student,
                           teacher=teacher)
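To make the new MLM/CLM switch concrete, here is a brief sketch (not part of the commit) of two argument combinations that satisfy the sanity_checks() added in this change. The weight values and model pairings are illustrative, not script defaults.

# Sketch (not part of this commit): the two training modes accepted by sanity_checks().
mlm_run = dict(  # masked-LM distillation, e.g. BERT -> DistilBERT or RoBERTa -> DistilRoBERTa
    mlm=True, alpha_mlm=0.5, alpha_clm=0.0,
    student_type='distilbert', teacher_type='bert',
)
clm_run = dict(  # causal-LM distillation, e.g. GPT2 -> DistilGPT2
    mlm=False, alpha_mlm=0.0, alpha_clm=0.5,
    student_type='gpt2', teacher_type='gpt2',
)

for run in (mlm_run, clm_run):
    # Mirrors the first two assertions of sanity_checks(): the MLM weight is active
    # only when --mlm is set, and exactly one of the MLM / CLM objectives is used.
    assert (run['mlm'] and run['alpha_mlm'] > 0.) or (not run['mlm'] and run['alpha_mlm'] == 0.)
    assert (run['alpha_mlm'] > 0. and run['alpha_clm'] == 0.) or \
           (run['alpha_mlm'] == 0. and run['alpha_clm'] > 0.)
print('both configurations satisfy the new checks')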