Commit d9e60f4f authored by thomwolf

Merge branch 'master' into pr/1383

parents 07d055f8 1c507995
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
import torch
import argparse
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Extraction of some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation")
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default='roberta-large', type=str)
    parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str)
    parser.add_argument("--vocab_transform", action='store_true')
    args = parser.parse_args()

    if args.model_type == 'roberta':
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = 'roberta'
    elif args.model_type == 'gpt2':
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        prefix = 'transformer'

    state_dict = model.state_dict()
    compressed_sd = {}

    ### Embeddings ###
    if args.model_type == 'gpt2':
        for param_name in ['wte.weight', 'wpe.weight']:
            compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}']
    else:
        for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']:
            param_name = f'{prefix}.embeddings.{w}.weight'
            compressed_sd[param_name] = state_dict[param_name]
        for w in ['weight', 'bias']:
            param_name = f'{prefix}.embeddings.LayerNorm.{w}'
            compressed_sd[param_name] = state_dict[param_name]

    ### Transformer Blocks ###
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        if args.model_type == 'gpt2':
            for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']:
                for w in ['weight', 'bias']:
                    compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \
                        state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}']
            compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias']
        else:
            for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value',
                          'attention.output.dense', 'attention.output.LayerNorm',
                          'intermediate.dense', 'output.dense', 'output.LayerNorm']:
                for w in ['weight', 'bias']:
                    compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \
                        state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}']
        std_idx += 1

    ### Language Modeling Head ###
    if args.model_type == 'roberta':
        for layer in ['lm_head.decoder.weight', 'lm_head.bias']:
            compressed_sd[f'{layer}'] = state_dict[f'{layer}']
        if args.vocab_transform:
            for w in ['weight', 'bias']:
                compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}']
                compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
    elif args.model_type == 'gpt2':
        for w in ['weight', 'bias']:
            compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}']
        compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight']

    print(f'N layers selected for distillation: {std_idx}')
    print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}')
    print(f'Save transferred checkpoint to {args.dump_checkpoint}.')
    torch.save(compressed_sd, args.dump_checkpoint)
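
For reference, a minimal sketch (not part of the commit) of how the dumped checkpoint could be inspected and loaded into a 6-layer student. The config handling below is an assumption for illustration only; the actual student initialization is handled by the training script.

import torch
from transformers import RobertaConfig, RobertaForMaskedLM

# Load the compressed state dict produced by the script above (default dump path assumed).
compressed_sd = torch.load('serialization_dir/tf_roberta_048131723.pth', map_location='cpu')
print(f'{len(compressed_sd)} tensors in the extracted checkpoint')

# Assumption for illustration: build a 6-layer student from the teacher's config
# and warm-start it with the extracted weights.
student_config = RobertaConfig.from_pretrained('roberta-large')
student_config.num_hidden_layers = 6
student = RobertaForMaskedLM(student_config)
result = student.load_state_dict(compressed_sd, strict=False)
print(result)  # any missing/unexpected keys are reported here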
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 Preprocessing script before training DistilBERT.
+Specific to BERT -> DistilBERT.
 """
 from transformers import BertForMaskedLM, RobertaForMaskedLM
 import torch
@@ -21,7 +22,7 @@ import argparse
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation")
-    parser.add_argument("--model_type", default="bert", choices=["bert", "roberta"])
+    parser.add_argument("--model_type", default="bert", choices=["bert"])
     parser.add_argument("--model_name", default='bert-base-uncased', type=str)
     parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str)
     parser.add_argument("--vocab_transform", action='store_true')
@@ -31,9 +32,8 @@ if __name__ == '__main__':
     if args.model_type == 'bert':
         model = BertForMaskedLM.from_pretrained(args.model_name)
         prefix = 'bert'
-    elif args.model_type == 'roberta':
-        model = RobertaForMaskedLM.from_pretrained(args.model_name)
-        prefix = 'roberta'
+    else:
+        raise ValueError(f'args.model_type should be "bert".')
     state_dict = model.state_dict()
     compressed_sd = {}
@@ -68,20 +68,12 @@ if __name__ == '__main__':
                         state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}']
         std_idx += 1
-    if args.model_type == 'bert':
     compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight']
     compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias']
     if args.vocab_transform:
         for w in ['weight', 'bias']:
             compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}']
             compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}']
-    elif args.model_type == 'roberta':
-        compressed_sd[f'vocab_projector.weight'] = state_dict[f'lm_head.decoder.weight']
-        compressed_sd[f'vocab_projector.bias'] = state_dict[f'lm_head.bias']
-        if args.vocab_transform:
-            for w in ['weight', 'bias']:
-                compressed_sd[f'vocab_transform.{w}'] = state_dict[f'lm_head.dense.{w}']
-                compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}']
     print(f'N layers selected for distillation: {std_idx}')
     print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}')
...
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Preprocessing script before training DistilBERT.
+Preprocessing script before training the distilled model.
 """
 from collections import Counter
 import argparse
...
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Training DistilBERT.
+Training the distilled model.
+Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
 """
 import os
 import argparse
@@ -23,68 +24,96 @@ import shutil
 import numpy as np
 import torch
-from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM
-from transformers import DistilBertForMaskedLM, DistilBertConfig
+from transformers import BertConfig, BertForMaskedLM, BertTokenizer
+from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
+from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 from distiller import Distiller
 from utils import git_log, logger, init_gpu_params, set_seed
-from dataset import Dataset
+from lm_seqs_dataset import LmSeqsDataset
+MODEL_CLASSES = {
+    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+}
+def sanity_checks(args):
+    """
+    A bunch of args sanity checks to perform even starting...
+    """
+    assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.)
+    assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.)
+    if args.mlm:
+        assert os.path.isfile(args.token_counts)
+        assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert'])
+    else:
+        assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2'])
+    assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert')
+    assert os.path.isfile(args.student_config)
+    if args.student_pretrained_weights is not None:
+        assert os.path.isfile(args.student_pretrained_weights)
+    if args.freeze_token_type_embds: assert args.student_type in ['roberta']
+    assert args.alpha_ce >= 0.
+    assert args.alpha_mlm >= 0.
+    assert args.alpha_clm >= 0.
+    assert args.alpha_mse >= 0.
+    assert args.alpha_cos >= 0.
+    assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
+def freeze_pos_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
+    elif args.student_type == 'gpt2':
+        student.transformer.wpe.weight.requires_grad = False
+def freeze_token_type_embeddings(student, args):
+    if args.student_type == 'roberta':
+        student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
 def main():
     parser = argparse.ArgumentParser(description="Training")
+    parser.add_argument("--force", action='store_true',
+                        help="Overwrite dump_path if it already exists.")
     parser.add_argument("--dump_path", type=str, required=True,
                         help="The output directory (log, checkpoints, parameters, etc.)")
     parser.add_argument("--data_file", type=str, required=True,
                         help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.")
-    parser.add_argument("--token_counts", type=str, required=True,
-                        help="The token counts in the data_file for MLM.")
-    parser.add_argument("--force", action='store_true',
-                        help="Overwrite dump_path if it already exists.")
-    parser.add_argument("--vocab_size", default=30522, type=int,
-                        help="The vocabulary size.")
-    parser.add_argument("--max_position_embeddings", default=512, type=int,
-                        help="Maximum sequence length we can model (including [CLS] and [SEP]).")
-    parser.add_argument("--sinusoidal_pos_embds", action='store_false',
-                        help="If true, the position embeddings are simply fixed with sinusoidal embeddings.")
-    parser.add_argument("--n_layers", default=6, type=int,
-                        help="Number of Transformer blocks.")
-    parser.add_argument("--n_heads", default=12, type=int,
-                        help="Number of heads in the self-attention module.")
-    parser.add_argument("--dim", default=768, type=int,
-                        help="Dimension through the network. Must be divisible by n_heads")
-    parser.add_argument("--hidden_dim", default=3072, type=int,
-                        help="Intermediate dimension in the FFN.")
-    parser.add_argument("--dropout", default=0.1, type=float,
-                        help="Dropout.")
-    parser.add_argument("--attention_dropout", default=0.1, type=float,
-                        help="Dropout in self-attention.")
-    parser.add_argument("--activation", default='gelu', type=str,
-                        help="Activation to use in self-attention")
-    parser.add_argument("--tie_weights_", action='store_false',
-                        help="If true, we tie the embeddings matrix with the projection over the vocabulary matrix. Default is true.")
-    parser.add_argument("--from_pretrained_weights", default=None, type=str,
-                        help="Load student initialization checkpoint.")
-    parser.add_argument("--from_pretrained_config", default=None, type=str,
-                        help="Load student initialization architecture config.")
-    parser.add_argument("--teacher_type", default="bert", choices=["bert", "roberta"],
-                        help="Teacher type (BERT, RoBERTa).")
-    parser.add_argument("--teacher_name", default="bert-base-uncased", type=str,
+    parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True,
+                        help="The student type (DistilBERT, RoBERTa).")
+    parser.add_argument("--student_config", type=str, required=True,
+                        help="Path to the student configuration.")
+    parser.add_argument("--student_pretrained_weights", default=None, type=str,
+                        help="Load student initialization checkpoint.")
+    parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True,
+                        help="Teacher type (BERT, RoBERTa).")
+    parser.add_argument("--teacher_name", type=str, required=True,
                         help="The teacher model.")
     parser.add_argument("--temperature", default=2., type=float,
                         help="Temperature for the softmax temperature.")
     parser.add_argument("--alpha_ce", default=0.5, type=float,
                         help="Linear weight for the distillation loss. Must be >=0.")
-    parser.add_argument("--alpha_mlm", default=0.5, type=float,
-                        help="Linear weight for the MLM loss. Must be >=0.")
+    parser.add_argument("--alpha_mlm", default=0.0, type=float,
                        help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.")
+    parser.add_argument("--alpha_clm", default=0.5, type=float,
+                        help="Linear weight for the CLM loss. Must be >=0.")
     parser.add_argument("--alpha_mse", default=0.0, type=float,
                         help="Linear weight of the MSE loss. Must be >=0.")
     parser.add_argument("--alpha_cos", default=0.0, type=float,
                         help="Linear weight of the cosine embedding loss. Must be >=0.")
+    parser.add_argument("--mlm", action="store_true",
+                        help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.")
     parser.add_argument("--mlm_mask_prop", default=0.15, type=float,
                         help="Proportion of tokens for which we need to make a prediction.")
     parser.add_argument("--word_mask", default=0.8, type=float,
@@ -95,17 +124,20 @@ def main():
                         help="Proportion of tokens to randomly replace.")
     parser.add_argument("--mlm_smoothing", default=0.7, type=float,
                         help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
+    parser.add_argument("--token_counts", type=str,
+                        help="The token counts in the data_file for MLM.")
     parser.add_argument("--restrict_ce_to_mask", action='store_true',
                         help="If true, compute the distilation loss only the [MLM] prediction distribution.")
+    parser.add_argument("--freeze_pos_embs", action="store_true",
+                        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.")
+    parser.add_argument("--freeze_token_type_embds", action="store_true",
+                        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.")
     parser.add_argument("--n_epoch", type=int, default=3,
                         help="Number of pass on the whole dataset.")
     parser.add_argument("--batch_size", type=int, default=5,
                         help="Batch size (for each process).")
-    parser.add_argument("--tokens_per_batch", type=int, default=-1,
-                        help="If specified, modify the batches so that they have approximately this number of tokens.")
-    parser.add_argument("--shuffle", action='store_false',
-                        help="If true, shuffle the sequence order. Default is true.")
     parser.add_argument("--group_by_size", action='store_false',
                         help="If true, group sequences that have similar length into the same batch. Default is true.")
@@ -141,6 +173,7 @@ def main():
     parser.add_argument("--checkpoint_interval", type=int, default=4000,
                         help="Checkpoint interval.")
     args = parser.parse_args()
+    sanity_checks(args)
     ## ARGS ##
@@ -164,21 +197,19 @@ def main():
     with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
         json.dump(vars(args), f, indent=4)
     git_log(args.dump_path)
-    assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
-           (args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
+    student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
+    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
     ### TOKENIZER ###
-    if args.teacher_type == 'bert':
-        tokenizer = BertTokenizer.from_pretrained(args.teacher_name)
-    elif args.teacher_type == 'roberta':
-        tokenizer = RobertaTokenizer.from_pretrained(args.teacher_name)
+    tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
         idx = tokenizer.all_special_tokens.index(tok_symbol)
         special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
     logger.info(f'Special tokens {special_tok_ids}')
     args.special_tok_ids = special_tok_ids
+    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
     ## DATA LOADER ##
@@ -187,35 +218,34 @@ def main():
         data = pickle.load(fp)
-    assert os.path.isfile(args.token_counts)
+    if args.mlm:
         logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)')
         with open(args.token_counts, 'rb') as fp:
             counts = pickle.load(fp)
-    assert len(counts) == args.vocab_size
         token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
         for idx in special_tok_ids.values():
             token_probs[idx] = 0.  # do not predict special tokens
         token_probs = torch.from_numpy(token_probs)
+    else:
+        token_probs = None
-    train_dataloader = Dataset(params=args, data=data)
+    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f'Data loader created.')
     ## STUDENT ##
-    if args.from_pretrained_weights is not None:
-        assert os.path.isfile(args.from_pretrained_weights)
-        assert os.path.isfile(args.from_pretrained_config)
-        logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
-        logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
-        stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
+    logger.info(f'Loading student config from {args.student_config}')
+    stu_architecture_config = student_config_class.from_pretrained(args.student_config)
     stu_architecture_config.output_hidden_states = True
-        student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
+    if args.student_pretrained_weights is not None:
+        logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}')
+        student = student_model_class.from_pretrained(args.student_pretrained_weights,
                                                       config=stu_architecture_config)
     else:
-        args.vocab_size_or_config_json_file = args.vocab_size
-        stu_architecture_config = DistilBertConfig(**vars(args), output_hidden_states=True)
-        student = DistilBertForMaskedLM(stu_architecture_config)
+        student = student_model_class(stu_architecture_config)
     if args.n_gpu > 0:
@@ -224,18 +254,31 @@ def main():
     ## TEACHER ##
-    if args.teacher_type == 'bert':
-        teacher = BertForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
-    elif args.teacher_type == 'roberta':
-        teacher = RobertaForMaskedLM.from_pretrained(args.teacher_name, output_hidden_states=True)
+    teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f'cuda:{args.local_rank}')
     logger.info(f'Teacher loaded from {args.teacher_name}.')
+    ## FREEZING ##
+    if args.freeze_pos_embs:
+        freeze_pos_embeddings(student, args)
+    if args.freeze_token_type_embds:
+        freeze_token_type_embeddings(student, args)
+    ## SANITY CHECKS ##
+    assert student.config.vocab_size == teacher.config.vocab_size
+    assert student.config.hidden_size == teacher.config.hidden_size
+    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
+    if args.mlm:
+        assert token_probs.size(0) == stu_architecture_config.vocab_size
     ## DISTILLER ##
     torch.cuda.empty_cache()
     distiller = Distiller(params=args,
-                          dataloader=train_dataloader,
+                          dataset=train_lm_seq_dataset,
                           token_probs=token_probs,
                           student=student,
                           teacher=teacher)
...
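
As a small worked example (toy numbers, not from the commit) of the MLM smoothing used in the training script above: token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing raises token counts to the power -0.7 by default, so rarer tokens get proportionally more masking weight, and special tokens are zeroed out so they are never selected for prediction.

import numpy as np

counts = np.array([500000, 10000, 100, 10])   # toy counts: a special token, a frequent, a rare, a very rare token
token_probs = np.maximum(counts, 1) ** -0.7   # mlm_smoothing = 0.7, the script's default
token_probs[0] = 0.                           # special tokens are never selected for prediction
print(token_probs / token_probs.sum())        # relative masking propensity per token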
{
"activation": "gelu",
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"n_heads": 12,
"n_layers": 6,
"sinusoidal_pos_embds": true,
"tie_weights_": true,
"vocab_size": 30522
}
\ No newline at end of file
{
"initializer_range": 0.02,
"layer_norm_epsilon": 0.00001,
"n_ctx": 1024,
"n_embd": 768,
"n_head": 12,
"n_layer": 6,
"n_positions": 1024,
"vocab_size": 50257
}
\ No newline at end of file
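
The two JSON files above are the student configurations consumed through the training script's --student_config flag. A minimal sketch (the local file name below is an assumption for illustration) of how such a file is turned into a freshly initialized student, mirroring the student_config_class.from_pretrained(args.student_config) call in the diff above:

from transformers import DistilBertConfig, DistilBertForMaskedLM

# Assumption: the DistilBERT JSON above is saved locally as distilbert_student_config.json
config = DistilBertConfig.from_pretrained('distilbert_student_config.json')
config.output_hidden_states = True   # the distillation losses need the intermediate hidden states
student = DistilBertForMaskedLM(config)
print(config.n_layers, config.dim)   # 6 layers, hidden size 768, as specified in the JSON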
@@ -26,13 +26,15 @@ import torch
 import torch.nn.functional as F
 import numpy as np
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, CTRLConfig
+from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
 from transformers import XLNetLMHeadModel, XLNetTokenizer
 from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
 from transformers import CTRLLMHeadModel, CTRLTokenizer
+from transformers import XLMWithLMHeadModel, XLMTokenizer
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -41,7 +43,7 @@ logger = logging.getLogger(__name__)
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, CTRLConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
@@ -49,6 +51,7 @@ MODEL_CLASSES = {
     'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
     'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
+    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
 }
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -104,7 +107,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
@@ -122,6 +125,9 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 target_mapping[0, 0, -1] = 1.0  # predict last token
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
+            if xlm_lang is not None:
+                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
@@ -146,6 +152,7 @@ def main():
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--padding_text", type=str, default="")
+    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--temperature", type=float, default=1.0,
                         help="temperature of 0 implies greedy sampling")
@@ -157,6 +164,8 @@ def main():
                         help="Avoid using CUDA when available")
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
+    parser.add_argument('--stop_token', type=str, default=None,
+                        help="Token at which text generation is stopped")
     args = parser.parse_args()
     if args.model_type in ["ctrl"]:
         if args.temperature > 0.7 :
@@ -183,6 +192,18 @@ def main():
     print(args)
     while True:
+        xlm_lang = None
+        # XLM Language usage detailed in the issues #1414
+        if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
+                and model.config.use_lang_emb:
+            if args.xlm_lang:
+                language = args.xlm_lang
+            else:
+                language = None
+                while language not in tokenizer.lang2id.keys():
+                    language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
+            xlm_lang = tokenizer.lang2id[language]
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory likes to have a long prompt for short inputs.
@@ -196,11 +217,15 @@ def main():
             top_k=args.top_k,
             top_p=args.top_p,
             repetition_penalty=args.repetition_penalty,
-            device=args.device,
             is_xlnet=bool(args.model_type == "xlnet"),
+            xlm_lang=xlm_lang,
+            device=args.device,
         )
         out = out[0, len(context_tokens):].tolist()
-        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
+        text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+        text = text[: text.find(args.stop_token) if args.stop_token else None]
         print(text)
         if args.prompt:
             break
...
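
A brief sketch (illustrative only; the checkpoint name and language id handling are assumptions, not part of the commit) of what the new `langs` input built in sample_sequence corresponds to at the model level: each position gets the selected language id taken from tokenizer.lang2id.

import torch
from transformers import XLMTokenizer, XLMWithLMHeadModel

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')   # assumed multilingual checkpoint with language embeddings
model = XLMWithLMHeadModel.from_pretrained('xlm-mlm-enfr-1024')

input_ids = torch.tensor([tokenizer.encode("Hello, how are you?")])
langs = torch.full_like(input_ids, tokenizer.lang2id['en'])     # one language id per token position
outputs = model(input_ids, langs=langs)                         # same contract as the `langs` tensor built in the script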
@@ -53,7 +53,8 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
 logger = logging.getLogger(__name__)
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig,
+                                                                                RobertaConfig, DistilBertConfig)), ())
 MODEL_CLASSES = {
     'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
@@ -248,7 +249,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         result = compute_metrics(eval_task, preds, out_label_ids)
         results.update(result)
-        output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
+        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
             logger.info("***** Eval results {} *****".format(prefix))
             for key in sorted(result.keys()):
@@ -489,9 +490,11 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
+            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
-            result = evaluate(args, model, tokenizer, prefix=global_step)
+            result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
             results.update(result)
...
@@ -59,7 +59,7 @@ class TextDataset(Dataset):
     def __init__(self, tokenizer, file_path='train', block_size=512):
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
+        cached_features_file = os.path.join(directory, 'cached_lm_{}_{}'.format(block_size, filename))
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
@@ -282,7 +282,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         "perplexity": perplexity
     }
-    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
+    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
     with open(output_eval_file, "w") as writer:
         logger.info("***** Eval results {} *****".format(prefix))
         for key in sorted(result.keys()):
@@ -484,9 +484,11 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
+            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
-            result = evaluate(args, model, tokenizer, prefix=global_step)
+            result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
             results.update(result)
...
@@ -512,9 +512,11 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
+            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
-            result = evaluate(args, model, tokenizer, prefix=global_step)
+            result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
             results.update(result)
@@ -528,9 +530,11 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
+            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
-            result = evaluate(args, model, tokenizer, prefix=global_step, test=True)
+            result = evaluate(args, model, tokenizer, prefix=prefix, test=True)
             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
             results.update(result)
     if best_steps:
...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
+""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
@@ -135,9 +135,10 @@ def train(args, train_dataset, model, tokenizer):
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
                       'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type == 'xlm' else batch[2],
                       'start_positions': batch[3],
                       'end_positions': batch[4]}
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[5],
                                'p_mask': batch[6]})
@@ -218,9 +219,10 @@ def evaluate(args, model, tokenizer, prefix=""):
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
             inputs = {'input_ids': batch[0],
-                      'attention_mask': batch[1],
-                      'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
+                      'attention_mask': batch[1]
                       }
+            if args.model_type != 'distilbert':
+                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]
             if args.model_type in ['xlnet', 'xlm']:
                 inputs.update({'cls_index': batch[4],
...
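
The DistilBERT branches added above exist because, unlike BERT/XLNet/XLM, DistilBERT has no token type (segment) embeddings, so its forward pass is called without token_type_ids. A minimal sketch (illustrative only, not from the commit) of the resulting call shape for the QA model:

import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')  # QA head is newly initialized here

input_ids = torch.tensor([tokenizer.encode("Who wrote it?", "It was written by Victor Hugo.")])
attention_mask = torch.ones_like(input_ids)
start_scores, end_scores = model(input_ids, attention_mask=attention_mask)[:2]  # no token_type_ids argument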
absl-py==0.8.0
astor==0.8.0
atomicwrites==1.3.0
attrs==19.2.0
boto3==1.9.243
botocore==1.12.243
certifi==2019.9.11
chardet==3.0.4
Click==7.0
docutils==0.15.2
gast==0.2.2
google-pasta==0.1.7
grpcio==1.24.1
h5py==2.10.0
idna==2.8
importlib-metadata==0.23
jmespath==0.9.4
joblib==0.14.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
more-itertools==7.2.0
numpy==1.17.2
opt-einsum==3.1.0
packaging==19.2
pluggy==0.13.0
protobuf==3.10.0
py==1.8.0
pyparsing==2.4.2
pytest==5.2.1
python-dateutil==2.8.0
regex==2019.8.19
requests==2.22.0
s3transfer==0.2.1
sacremoses==0.0.35
sentencepiece==0.1.83
six==1.12.0
tensorboard==2.0.0
tensorflow==2.0.0
tensorflow-estimator==2.0.0
termcolor==1.1.0
torch==1.2.0
tqdm==4.36.1
urllib3==1.25.6
wcwidth==0.1.7
Werkzeug==0.16.0
wrapt==1.11.2
zipp==0.6.0
@@ -28,7 +28,8 @@ logger = logging.getLogger(__name__)
 GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
                                       "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
-                                      "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
+                                      "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
+                                      "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",}
 class GPT2Config(PretrainedConfig):
     """Configuration class to store the configuration of a `GPT2Model`.
...
@@ -178,10 +178,12 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
         else:
             model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
-        convert_pt_checkpoint_to_tf(model_type,
-                                    model_file,
-                                    config_file,
-                                    os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
+        if os.path.isfile(model_shortcut_name):
+            model_shortcut_name = 'converted_model'
+        convert_pt_checkpoint_to_tf(model_type=model_type,
+                                    pytorch_checkpoint_path=model_file,
+                                    config_file=config_file,
+                                    tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
                                     compare_with_pt_model=compare_with_pt_model)
         os.remove(config_file)
         os.remove(model_file)
...
@@ -118,7 +118,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
 def gelu(x):
-    """ Original Implementation of the gelu activation function in Google Bert repo when initialy created.
+    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
         Also see https://arxiv.org/abs/1606.08415
...
@@ -159,8 +159,6 @@ class MultiHeadSelfAttention(nn.Module):
         dim_per_head = self.dim // self.n_heads
-        assert 2 <= mask.dim() <= 3
-        causal = (mask.dim() == 3)
         mask_reshp = (bs, 1, 1, k_length)
         def shape(x):
@@ -649,7 +647,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
         start_positions = torch.tensor([1])
         end_positions = torch.tensor([3])
         outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
-        loss, start_scores, end_scores = outputs[:2]
+        loss, start_scores, end_scores = outputs[:3]
     """
     def __init__(self, config):
...
@@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
 GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
-                                     "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
+                                     "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
+                                     "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",}
 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
...
@@ -62,7 +62,7 @@ def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
 def gelu(x):
     """ Gaussian Error Linear Unit.
-        Original Implementation of the gelu activation function in Google Bert repo when initialy created.
+        Original Implementation of the gelu activation function in Google Bert repo when initially created.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
         Also see https://arxiv.org/abs/1606.08415
...
@@ -45,7 +45,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 ### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
 def gelu(x):
     """ Gaussian Error Linear Unit.
-        Original Implementation of the gelu activation function in Google Bert repo when initialy created.
+        Original Implementation of the gelu activation function in Google Bert repo when initially created.
         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
         0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
         Also see https://arxiv.org/abs/1606.08415
@@ -226,8 +226,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         dim_per_head = self.dim // self.n_heads
-        assert 2 <= len(tf.shape(mask)) <= 3
-        causal = (len(tf.shape(mask)) == 3)
         mask_reshape = [bs, 1, 1, k_length]
         def shape(x):
@@ -603,7 +601,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
         model = TFDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
-        outputs = model(input_ids, masked_lm_labels=input_ids)
+        outputs = model(input_ids)
         prediction_scores = outputs[0]
     """
@@ -715,9 +713,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
         tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
         model = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
-        start_positions = tf.constant([1])
-        end_positions = tf.constant([3])
-        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
+        outputs = model(input_ids)
         start_scores, end_scores = outputs[:2]
     """
...
@@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
 TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-tf_model.h5",
                                         "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-tf_model.h5",
-                                        "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"}
+                                        "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5",
+                                        "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}
 def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
...
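
With the distilgpt2 entries added to the configuration and weight archive maps above (and the 6-layer DistilGPT2 student config shown earlier), the distilled GPT2 checkpoint becomes loadable by its shortcut name; a minimal sketch:

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('distilgpt2')  # resolved through the archive map entries added above
print(model.config.n_layer)  # 6 transformer blocks, matching the DistilGPT2 student config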