Commit c51e533a authored by VictorSanh, committed by Victor SANH

update train.py

parent a76c3f9c
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Training DistilBERT.
+Training the distilled model.
+Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
 """
 import os
 import argparse
@@ -23,68 +24,96 @@ import shutil
 import numpy as np
 import torch
-from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
-from transformers import DistilBertForMaskedLM, DistilBertConfig
+from transformers import BertConfig, BertForMaskedLM, BertTokenizer
+from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 
 from distiller import Distiller
 from utils import git_log, logger, init_gpu_params, set_seed
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+
+
+MODEL_CLASSES = {
+    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+}
+
+
+def sanity_checks(args):
+    """
+    A bunch of args sanity checks to perform even before starting...
+    """
+    assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
+    assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
+    if args.mlm:
+        assert os.path.isfile(args.token_counts)
+        assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
+    else:
+        assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
+
+    assert args.teacher_type == args.student_type or (args.student_type == 'distilbert' and args.teacher_type == 'bert')
+    assert os.path.isfile(args.student_config)
+    if args.student_pretrained_weights is not None:
+        assert os.path.isfile(args.student_pretrained_weights)
+
+    if args.freeze_token_type_embds:
+        assert args.student_type in ['roberta']
+
+    assert args.alpha_ce >= 0.
+    assert args.alpha_mlm >= 0.
+    assert args.alpha_clm >= 0.
+    assert args.alpha_mse >= 0.
+    assert args.alpha_cos >= 0.
+    assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
+
+
+def freeze_pos_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
+    elif args.student_type == 'gpt2':
+        student.transformer.wpe.weight.requires_grad = False
+
+
+def freeze_token_type_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
+
+
 def main():
     parser = argparse.ArgumentParser(description="Training")
+    parser.add_argument("--force", action='store_true',
+                        help="Overwrite dump_path if it already exists.")
     parser.add_argument("--dump_path", type=str, required=True,
                         help="The output directory (log, checkpoints, parameters, etc.)")
     parser.add_argument("--data_file", type=str, required=True,
                         help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
-    parser.add_argument("--token_counts", type=str, required=True,
-                        help="The token counts in the data_file for MLM.")
-    parser.add_argument("--force", action='store_true',
-                        help="Overwrite dump_path if it already exists.")
-    parser.add_argument("--vocab_size", default=30522, type=int,
-                        help="The vocabulary size.")
-    parser.add_argument("--max_position_embeddings", default=512, type=int,
-                        help="Maximum sequence length we can model (including [CLS] and [SEP]).")
-    parser.add_argument("--sinusoidal_pos_embds", action='store_false',
-                        help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
-    parser.add_argument("--n_layers", default=6, type=int,
-                        help="Number of Transformer blocks.")
-    parser.add_argument("--n_heads", default=12, type=int,
-                        help="Number of heads in the self-attention module.")
-    parser.add_argument("--dim", default=768, type=int,
-                        help="Dimension through the network. Must be divisible by n_heads")
-    parser.add_argument("--hidden_dim", default=3072, type=int,
-                        help="Intermediate dimension in the FFN.")
-    parser.add_argument("--dropout", default=0.1, type=float,
-                        help="Dropout.")
-    parser.add_argument("--attention_dropout", default=0.1, type=float,
-                        help="Dropout in self-attention.")
-    parser.add_argument("--activation", default='gelu', type=str,
-                        help="Activation to use in self-attention")
-    parser.add_argument("--tie_weights_", action='store_false',
-                        help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")
-    parser.add_argument("--from_pretrained_weights", default=None, type=str,
-                        help="Load student initialization checkpoint.")
-    parser.add_argument("--from_pretrained_config", default=None, type=str,
-                        help="Load student initialization architecture config.")
-    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
-                        help="Teacher type (BERT, RoBERTa).")
-    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
-                        help="The teacher model.")
+    parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
+                        help="The student type (DistilBERT, RoBERTa).")
+    parser.add_argument("--student_config", type=str, required=True,
+                        help="Path to the student configuration.")
+    parser.add_argument("--student_pretrained_weights", default=None, type=str,
+                        help="Load student initialization checkpoint.")
+    parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
+                        help="Teacher type (BERT, RoBERTa).")
+    parser.add_argument("--teacher_name", type=str, required=True,
+                        help="The teacher model.")
parser.add_argument("--temperature", default=2., type=float, parser.add_argument("--temperature", default=2., type=float,
help="Temperature for the softmax temperature.") help="Temperature for the softmax temperature.")
parser.add_argument("--alpha_ce", default=0.5, type=float, parser.add_argument("--alpha_ce", default=0.5, type=float,
help="Linear weight for the distillation loss. Must be >=0.") help="Linear weight for the distillation loss. Must be >=0.")
parser.add_argument("--alpha_mlm", default=0.5, type=float, parser.add_argument("--alpha_mlm", default=0.0, type=float,
help="Linear weight for the MLM loss. Must be >=0.") help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
parser.add_argument("--alpha_clm", default=0.5, type=float,
help="Linear weight for the CLM loss. Must be >=0.")
parser.add_argument("--alpha_mse", default=0.0, type=float, parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.") help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--alpha_cos", default=0.0, type=float, parser.add_argument("--alpha_cos", default=0.0, type=float,
help="Linear weight of the cosine embedding loss. Must be >=0.") help="Linear weight of the cosine embedding loss. Must be >=0.")
parser.add_argument("--mlm", action="store_true",
help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
parser.add_argument("--mlm_mask_prop", default=0.15, type=float, parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
help="Proportion of tokens for which we need to make a prediction.") help="Proportion of tokens for which we need to make a prediction.")
parser.add_argument("--word_mask", default=0.8, type=float, parser.add_argument("--word_mask", default=0.8, type=float,
@@ -95,17 +124,20 @@ def main():
                         help="Proportion of tokens to randomly replace.")
     parser.add_argument("--mlm_smoothing", default=0.7, type=float,
                         help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
+    parser.add_argument("--token_counts", type=str,
+                        help="The token counts in the data_file for MLM.")
     parser.add_argument("--restrict_ce_to_mask", action='store_true',
                         help="If true, compute the distillation loss only on the [MLM] prediction distribution.")
+    parser.add_argument("--freeze_pos_embs", action="store_true",
+                        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
+    parser.add_argument("--freeze_token_type_embds", action="store_true",
+                        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")
     parser.add_argument("--n_epoch", type=int, default=3,
                         help="Number of passes on the whole dataset.")
     parser.add_argument("--batch_size", type=int, default=5,
                         help="Batch size (for each process).")
-    parser.add_argument("--tokens_per_batch", type=int, default=-1,
-                        help="If specified, modify the batches so that they have approximately this number of tokens.")
-    parser.add_argument("--shuffle", action='store_false',
-                        help="If true, shuffle the sequence order. Default is true.")
     parser.add_argument("--group_by_size", action='store_false',
                         help="If true, group sequences that have similar length into the same batch. Default is true.")
@@ -141,6 +173,7 @@ def main():
     parser.add_argument("--checkpoint_interval", type=int, default=4000,
                         help="Checkpoint interval.")
     args = parser.parse_args()
+    sanity_checks(args)
 
     ## ARGS ##
@@ -164,21 +197,19 @@ def main():
     with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
         json.dump(vars(args), f, indent=4)
     git_log(args.dump_path)
-    assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
-           (args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
+    student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
+    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
 
     ### TOKENIZER ###
-    if args.teacher_type == 'bert':
-        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
-    elif args.teacher_type == 'roberta':
-        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
+    tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
         idx = tokenizer.all_special_tokens.index(tok_symbol)
         special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
     logger.info(f'Special tokens {special_tok_ids}')
     args.special_tok_ids = special_tok_ids
+    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
 
     ## DATA LOADER ##
@@ -187,35 +218,34 @@ def main():
         data = pickle.load(fp)
 
-    assert os.path.isfile(args.token_counts)
-    logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
-    with open(args.token_counts, 'rb') as fp:
-        counts = pickle.load(fp)
-    assert len(counts) == args.vocab_size
-    token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
-    for idx in special_tok_ids.values():
-        token_probs[idx] = 0.  # do not predict special tokens
-    token_probs = torch.from_numpy(token_probs)
+    if args.mlm:
+        logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
+        with open(args.token_counts, 'rb') as fp:
+            counts = pickle.load(fp)
+        token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
+        for idx in special_tok_ids.values():
+            token_probs[idx] = 0.  # do not predict special tokens
+        token_probs = torch.from_numpy(token_probs)
+    else:
+        token_probs = None
 
-    train_dataloader = Dataset(params=args, data=data)
+    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f'Data loader created.')
 
     ## STUDENT ##
-    if args.from_pretrained_weights is not None:
-        assert os.path.isfile(args.from_pretrained_weights)
-        assert os.path.isfile(args.from_pretrained_config)
-        logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
-        logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
-        stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
-        stu_architecture_config.output_hidden_states = True
-        student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
-                                                        config=stu_architecture_config)
-    else:
-        args.vocab_size_or_config_json_file = args.vocab_size
-        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
-        student = DistilBertForMaskedLM(stu_architecture_config)
+    logger.info(f'Loading student config from {args.student_config}')
+    stu_architecture_config = student_config_class.from_pretrained(args.student_config)
+    stu_architecture_config.output_hidden_states = True
+    if args.student_pretrained_weights is not None:
+        logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
+        student = student_model_class.from_pretrained(args.student_pretrained_weights,
+                                                      config=stu_architecture_config)
+    else:
+        student = student_model_class(stu_architecture_config)
 
     if args.n_gpu > 0:
@@ -224,18 +254,31 @@ def main():
     ## TEACHER ##
-    if args.teacher_type == 'bert':
-        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
-    elif args.teacher_type == 'roberta':
-        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
+    teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f'cuda:{args.local_rank}')
     logger.info(f'Teacher loaded from {args.teacher_name}.')
 
+    ## FREEZING ##
+    if args.freeze_pos_embs:
+        freeze_pos_embeddings(student, args)
+    if args.freeze_token_type_embds:
+        freeze_token_type_embeddings(student, args)
+
+    ## SANITY CHECKS ##
+    assert student.config.vocab_size == teacher.config.vocab_size
+    assert student.config.hidden_size == teacher.config.hidden_size
+    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
+    if args.mlm:
+        assert token_probs.size(0) == stu_architecture_config.vocab_size
+
     ## DISTILLER ##
     torch.cuda.empty_cache()
     distiller = Distiller(params=args,
-                          dataloader=train_dataloader,
+                          dataset=train_lm_seq_dataset,
                           token_probs=token_probs,
                           student=student,
                           teacher=teacher)
...
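
For illustration, one possible invocation of the updated script with the new student/teacher flags is sketched below. It is not part of the commit: the file paths and config name are placeholders, and only options defined in this train.py are used (a BERT teacher distilled into a DistilBERT student with the MLM objective, so --alpha_clm is set to 0 as required by sanity_checks).

python train.py \
    --force \
    --dump_path serialization_dir/my_distillation_run \
    --data_file data/binarized_text.bert-base-uncased.pickle \
    --token_counts data/token_counts.bert-base-uncased.pickle \
    --student_type distilbert \
    --student_config training_configs/distilbert-base-uncased.json \
    --teacher_type bert \
    --teacher_name bert-base-uncased \
    --mlm --alpha_ce 0.5 --alpha_mlm 0.5 --alpha_clm 0.0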