Commit f5891c38 authored by VictorSanh's avatar VictorSanh Committed by Victor SANH
Browse files

run_squad --> run_squad_w_distillation

parent 764a7923
This diff is collapsed.
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet) with an optional step of distillation.""" """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
...@@ -28,8 +28,6 @@ import torch ...@@ -28,8 +28,6 @@ import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset) TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm, trange from tqdm import tqdm, trange
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
...@@ -75,7 +73,7 @@ def set_seed(args): ...@@ -75,7 +73,7 @@ def set_seed(args):
def to_list(tensor): def to_list(tensor):
return tensor.detach().cpu().tolist() return tensor.detach().cpu().tolist()
def train(args, train_dataset, model, tokenizer, teacher=None): def train(args, train_dataset, model, tokenizer):
""" Train the model """ """ Train the model """
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter() tb_writer = SummaryWriter()
...@@ -134,8 +132,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -134,8 +132,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):
model.train() model.train()
if teacher is not None:
teacher.eval()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0], inputs = {'input_ids': batch[0],
'attention_mask': batch[1], 'attention_mask': batch[1],
...@@ -147,27 +143,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -147,27 +143,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
inputs.update({'cls_index': batch[5], inputs.update({'cls_index': batch[5],
'p_mask': batch[6]}) 'p_mask': batch[6]})
outputs = model(**inputs) outputs = model(**inputs)
loss, start_logits_stu, end_logits_stu = outputs loss = outputs[0] # model outputs are always tuple in transformers (see doc)
# Distillation loss
if teacher is not None:
if 'token_type_ids' not in inputs:
inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
with torch.no_grad():
start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
token_type_ids=inputs['token_type_ids'],
attention_mask=inputs['attention_mask'])
assert start_logits_tea.size() == start_logits_stu.size()
assert end_logits_tea.size() == end_logits_stu.size()
loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
loss_ce = (loss_start + loss_end)/2.
loss = args.alpha_ce*loss_ce + args.alpha_squad*loss
if args.n_gpu > 1: if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
...@@ -367,18 +343,6 @@ def main(): ...@@ -367,18 +343,6 @@ def main():
parser.add_argument("--output_dir", default=None, type=str, required=True, parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.") help="The output directory where the model checkpoints and predictions will be written.")
# Distillation parameters (optional)
parser.add_argument('--teacher_type', default=None, type=str,
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.")
parser.add_argument('--teacher_name_or_path', default=None, type=str,
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.")
parser.add_argument('--alpha_ce', default=0.5, type=float,
help="Distillation loss linear weight. Only for distillation.")
parser.add_argument('--alpha_squad', default=0.5, type=float,
help="True SQuAD loss linear weight. Only for distillation.")
parser.add_argument('--temperature', default=2.0, type=float,
help="Distillation temperature. Only for distillation.")
## Other parameters ## Other parameters
parser.add_argument("--config_name", default="", type=str, parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name") help="Pretrained config name or path if not the same as model_name")
...@@ -506,17 +470,6 @@ def main(): ...@@ -506,17 +470,6 @@ def main():
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
if args.teacher_type is not None:
assert args.teacher_name_or_path is not None
assert args.alpha_ce > 0.
assert args.alpha_ce + args.alpha_squad > 0.
assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path, config=teacher_config)
teacher.to(args.device)
else:
teacher = None
if args.local_rank == 0: if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
...@@ -528,7 +481,7 @@ def main(): ...@@ -528,7 +481,7 @@ def main():
# Training # Training
if args.do_train: if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher) global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment