Unverified Commit 54abc67a authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
@@ -86,6 +86,20 @@ jobs:
        - run: sudo pip install --progress-bar off -r docs/requirements.txt
        - run: sudo pip install --progress-bar off -r requirements.txt
        - run: ./.circleci/deploy.sh
    check_code_quality:
        working_directory: ~/transformers
        docker:
            - image: circleci/python:3.6
        resource_class: medium
        parallelism: 1
        steps:
            - checkout
            - run: sudo pip install --editable .
            - run: sudo pip install torch tensorflow
            - run: sudo pip install black git+git://github.com/timothycrosley/isort.git@e63ae06ec7d70b06df9e528357650281a3d3ec22#egg=isort flake8
            - run: black --check --line-length 119 examples templates transformers utils
            - run: isort --check-only --recursive examples templates transformers utils
            - run: flake8 examples templates transformers utils
    check_repository_consistency:
        working_directory: ~/transformers
        docker:
@@ -105,6 +119,7 @@ workflows:
    version: 2
    build_and_test:
        jobs:
            - check_code_quality
            - check_repository_consistency
            - run_examples_py3_torch
            - run_tests_py3_custom_tokenizers
...
.PHONY: style
style:
	black --line-length 119 examples templates transformers utils
	isort --recursive examples templates transformers utils
@@ -18,12 +18,14 @@
# If checking the tensors placement
# tf.debugging.set_log_device_placement(True)
import argparse
import csv
import timeit
from time import time
from typing import List

from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available


if is_tf_available():
    import tensorflow as tf
@@ -33,7 +35,6 @@ if is_torch_available():
    import torch
    from transformers import AutoModel


input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
the Director of Hatcheries and Conditioning entered the room, in the
@@ -247,7 +248,8 @@ the wall, slowly on into the Social Predestination Room.
as they entered."""


def create_setup_and_compute(
    model_names: List[str],
    gpu: bool = True,
    tensorflow: bool = False,
    average_over: int = 3,
@@ -256,7 +258,8 @@ def create_setup_and_compute(model_names: List[str],
    amp: bool = False,
    fp16: bool = False,
    save_to_csv: bool = False,
    csv_filename: str = f"results_{round(time())}.csv",
):
    if xla:
        tf.config.optimizer.set_jit(True)
    if amp:
@@ -266,7 +269,7 @@ def create_setup_and_compute(model_names: List[str],
        dictionary = {model_name: {} for model_name in model_names}
        results = _compute_tensorflow(model_names, dictionary, average_over, amp)
    else:
        device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
        dictionary = {model_name: {} for model_name in model_names}
        results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
@@ -276,22 +279,40 @@ def create_setup_and_compute(model_names: List[str],
        for batch_size in results[model_name]["bs"]:
            print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
            for slice_size in results[model_name]["ss"]:
                result = results[model_name]["results"][batch_size][slice_size]
                if isinstance(result, str):
                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
                else:
                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")

    if save_to_csv:
        with open(csv_filename, mode="w") as csv_file:
            fieldnames = [
                "model",
                "1x8",
                "1x64",
                "1x128",
                "1x256",
                "1x512",
                "1x1024",
                "2x8",
                "2x64",
                "2x128",
                "2x256",
                "2x512",
                "2x1024",
                "4x8",
                "4x64",
                "4x128",
                "4x256",
                "4x512",
                "4x1024",
                "8x8",
                "8x64",
                "8x128",
                "8x256",
                "8x512",
                "8x1024",
            ]

            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
@@ -299,11 +320,11 @@ def create_setup_and_compute(model_names: List[str],
            for model_name in model_names:
                model_results = {
                    f"{bs}x{ss}": results[model_name]["results"][bs][ss]
                    for bs in results[model_name]["results"]
                    for ss in results[model_name]["results"][bs]
                }
                writer.writerow({"model": model_name, **model_results})


def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
@@ -343,7 +364,7 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
                        print("Going through model with sequence of shape", sequence.shape)
                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
                    except RuntimeError as e:
                        print("Doesn't fit on GPU.", e)
@@ -379,7 +400,9 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                if max_input_size is not None and slice_size > max_input_size:
                    dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
                else:
                    sequence = tf.stack(
                        [tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size
                    )

                    try:
                        print("Going through model with sequence of shape", sequence.shape)
@@ -387,7 +410,7 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                        inference(sequence)
                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
                    except tf.errors.ResourceExhaustedError as e:
                        print("Doesn't fit on GPU.", e)
@@ -399,33 +422,64 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        required=False,
        type=str,
        default="all",
        help="Model checkpoints to be provided "
        "to the AutoModel classes. Leave "
        "blank to benchmark the base version "
        "of all available model "
        "architectures.",
    )
    parser.add_argument(
        "--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
    )
    parser.add_argument(
        "--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available " "cuda devices"
    )
    parser.add_argument(
        "--torchscript",
        required=False,
        action="store_true",
        help="Pytorch only: trace the models " "using torchscript",
    )
    parser.add_argument(
        "--tensorflow",
        required=False,
        action="store_true",
        help="Benchmark the TensorFlow version "
        "of the models. Will run on GPU if "
        "the correct dependencies are "
        "installed",
    )
    parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
    parser.add_argument(
        "--amp",
        required=False,
        action="store_true",
        help="TensorFlow only: use automatic mixed precision acceleration.",
    )
    parser.add_argument(
        "--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference."
    )
    parser.add_argument(
        "--keras_predict",
        required=False,
        action="store_true",
        help="Whether to use model.predict " "instead of model() to do a " "forward pass.",
    )
    parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
    parser.add_argument(
        "--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv."
    )
    parser.add_argument(
        "--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
    )
    args = parser.parse_args()
    if args.models == "all":
        args.models = [
            "gpt2",
            "bert-base-cased",
@@ -436,7 +490,7 @@ def main():
            "distilbert-base-uncased",
            "distilgpt2",
            "roberta-base",
            "ctrl",
        ]
    else:
        args.models = args.models.split()
@@ -453,7 +507,7 @@ def main():
                fp16=args.fp16,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
            )
        else:
            raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
@@ -467,11 +521,11 @@ def main():
                amp=args.amp,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
            )
        else:
            raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")


if __name__ == "__main__":
    main()
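A minimal sketch of driving the benchmark helper above directly from Python instead of through the CLI. It is not part of the diff: the module name `benchmarks` and the checkpoint list are assumptions, it only passes parameters visible in this diff, and it assumes the parameters elided by the hunk markers (such as `torchscript` and `xla`) keep default values.

    # Illustrative only: assumes the script above is importable as `benchmarks`
    # and that torch and transformers are installed.
    from benchmarks import create_setup_and_compute

    create_setup_and_compute(
        model_names=["bert-base-cased", "distilbert-base-uncased"],  # illustrative checkpoints
        gpu=False,           # run on CPU; set True to use CUDA when available
        tensorflow=False,    # exercise the PyTorch path (_compute_pytorch)
        average_over=3,      # repeat each measurement 3 times, as in the default above
        save_to_csv=False,
    )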
import torch

from transformers.modeling_camembert import CamembertForMaskedLM
from transformers.tokenization_camembert import CamembertTokenizer


def fill_mask(masked_input, model, tokenizer, topk=5):
    # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
    assert masked_input.count("<mask>") == 1
    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    logits = model(input_ids)[0]  # The last hidden-state is the first element of the output tuple
    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    logits = logits[0, masked_index, :]
    prob = logits.softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = " ".join(
        [tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
    )
    masked_token = tokenizer.mask_token
    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
        predicted_token = predicted_token_bpe.replace("\u2581", " ")
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append(
                (
                    masked_input.replace(" {0}".format(masked_token), predicted_token),
                    values[index].item(),
                    predicted_token,
                )
            )
        else:
            topk_filled_outputs.append(
                (masked_input.replace(masked_token, predicted_token), values[index].item(), predicted_token,)
            )
    return topk_filled_outputs


tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertForMaskedLM.from_pretrained("camembert-base")
model.eval()

masked_input = "Le camembert est <mask> :)"
...
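A short usage sketch for the `fill_mask` helper defined above, not part of the diff. It assumes the `camembert-base` weights were downloaded successfully by the `from_pretrained` calls and simply unpacks the (filled sentence, probability, predicted token) tuples the helper returns.

    # Illustrative only: print the top predictions for the masked sentence built above.
    topk_outputs = fill_mask(masked_input, model, tokenizer, topk=3)
    for filled_sentence, probability, predicted_token in topk_outputs:
        print(f"{probability:.3f}\t{predicted_token}\t{filled_sentence}")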
@@ -22,48 +22,57 @@
          --model_name openai-gpt \
          --do_train \
          --do_eval \
          --train_dataset "$ROC_STORIES_DIR/cloze_test_val__spring2016 - cloze_test_ALL_val.csv" \
          --eval_dataset "$ROC_STORIES_DIR/cloze_test_test__spring2016 - cloze_test_ALL_test.csv" \
          --output_dir ../log \
          --train_batch_size 16 \
"""
import argparse
import csv
import logging
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm import tqdm, trange

from transformers import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    AdamW,
    OpenAIGPTDoubleHeadsModel,
    OpenAIGPTTokenizer,
    cached_path,
    get_linear_schedule_with_warmup,
)


ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


def load_rocstories_dataset(dataset_path):
    """ Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
    with open(dataset_path, encoding="utf_8") as f:
        f = csv.reader(f)
        output = []
        next(f)  # skip the first line
        for line in tqdm(f):
            output.append((" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
    return output


def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
    """ Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
@@ -80,56 +89,68 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
        for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
            with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
            with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
            input_ids[i, 0, : len(with_cont1)] = with_cont1
            input_ids[i, 1, : len(with_cont2)] = with_cont2
            mc_token_ids[i, 0] = len(with_cont1) - 1
            mc_token_ids[i, 1] = len(with_cont2) - 1
            lm_labels[i, 0, : len(with_cont1)] = with_cont1
            lm_labels[i, 1, : len(with_cont2)] = with_cont2
            mc_labels[i] = mc_label
        all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
        tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
    return tensor_datasets


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="openai-gpt", help="pretrained model name")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--train_dataset", type=str, default="")
    parser.add_argument("--eval_dataset", type=str, default="")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", type=int, default=1)
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training \
                        steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before\
                        performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--lr_schedule", type=str, default="warmup_linear")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--lm_coef", type=float, default=0.9)
    parser.add_argument("--n_valid", type=int, default=374)
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()
@@ -152,7 +173,7 @@ def main():
    # Load tokenizer and model
    # This loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ["_start_", "_delimiter_", "_classify_"]
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
    tokenizer.add_tokens(special_tokens)
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
@@ -163,6 +184,7 @@ def main():
    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        roc_stories = cached_path(ROCSTORIES_URL)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
@@ -170,6 +192,7 @@ def main():
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    eval_dataset = load_rocstories_dataset(args.eval_dataset)
@@ -178,8 +201,11 @@ def main():
    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
    input_length = max(
        len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
        for dataset in encoded_datasets
        for story, cont1, cont2, _ in dataset
    )
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders
@@ -198,20 +224,23 @@ def main():
    if args.do_train:
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
@@ -230,14 +259,16 @@ def main():
                optimizer.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = (
                    loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                )
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, "module") else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
@@ -260,10 +291,12 @@ def main():
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                _, mc_loss, _, mc_logits = model(
                    input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels
                )

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to("cpu").numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()
@@ -274,10 +307,8 @@ def main():
        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_loss": train_loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
@@ -286,5 +317,6 @@ def main():
                logger.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    main()
...@@ -19,51 +19,48 @@ ...@@ -19,51 +19,48 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import argparse import argparse
import logging
import csv import csv
import glob
import logging
import os import os
import random import random
import sys import sys
import glob
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import (
WEIGHTS_NAME,
AdamW,
BertConfig,
BertForMultipleChoice,
BertTokenizer,
get_linear_schedule_with_warmup,
)
try: try:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
except: except ImportError:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig,
BertForMultipleChoice, BertTokenizer)
from transformers import AdamW, get_linear_schedule_with_warmup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
for conf in [BertConfig]), ())
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer), "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
} }
class SwagExample(object): class SwagExample(object):
"""A single training/test example for the SWAG dataset.""" """A single training/test example for the SWAG dataset."""
def __init__(self,
swag_id, def __init__(self, swag_id, context_sentence, start_ending, ending_0, ending_1, ending_2, ending_3, label=None):
context_sentence,
start_ending,
ending_0,
ending_1,
ending_2,
ending_3,
label = None):
self.swag_id = swag_id self.swag_id = swag_id
self.context_sentence = context_sentence self.context_sentence = context_sentence
self.start_ending = start_ending self.start_ending = start_ending
...@@ -79,7 +76,7 @@ class SwagExample(object): ...@@ -79,7 +76,7 @@ class SwagExample(object):
return self.__repr__() return self.__repr__()
def __repr__(self): def __repr__(self):
l = [ attributes = [
"swag_id: {}".format(self.swag_id), "swag_id: {}".format(self.swag_id),
"context_sentence: {}".format(self.context_sentence), "context_sentence: {}".format(self.context_sentence),
"start_ending: {}".format(self.start_ending), "start_ending: {}".format(self.start_ending),
...@@ -90,61 +87,53 @@ class SwagExample(object): ...@@ -90,61 +87,53 @@ class SwagExample(object):
] ]
if self.label is not None: if self.label is not None:
l.append("label: {}".format(self.label)) attributes.append("label: {}".format(self.label))
return ", ".join(l) return ", ".join(attributes)
class InputFeatures(object):
def __init__(self,
example_id,
choices_features,
label
): class InputFeatures(object):
def __init__(self, example_id, choices_features, label):
self.example_id = example_id self.example_id = example_id
self.choices_features = [ self.choices_features = [
{ {"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids
}
for _, input_ids, input_mask, segment_ids in choices_features for _, input_ids, input_mask, segment_ids in choices_features
] ]
self.label = label self.label = label
def read_swag_examples(input_file, is_training=True): def read_swag_examples(input_file, is_training=True):
with open(input_file, 'r', encoding='utf-8') as f: with open(input_file, "r", encoding="utf-8") as f:
reader = csv.reader(f) reader = csv.reader(f)
lines = [] lines = []
for line in reader: for line in reader:
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line) line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
lines.append(line) lines.append(line)
if is_training and lines[0][-1] != 'label': if is_training and lines[0][-1] != "label":
raise ValueError( raise ValueError("For training, the input file must contain a label column.")
"For training, the input file must contain a label column."
)
examples = [ examples = [
SwagExample( SwagExample(
swag_id = line[2], swag_id=line[2],
context_sentence = line[4], context_sentence=line[4],
start_ending = line[5], # in the swag dataset, the start_ending=line[5], # in the swag dataset, the
# common beginning of each # common beginning of each
# choice is stored in "sent2". # choice is stored in "sent2".
ending_0 = line[7], ending_0=line[7],
ending_1 = line[8], ending_1=line[8],
ending_2 = line[9], ending_2=line[9],
ending_3 = line[10], ending_3=line[10],
label = int(line[11]) if is_training else None label=int(line[11]) if is_training else None,
) for line in lines[1:] # we skip the line with the column names )
for line in lines[1:] # we skip the line with the column names
] ]
return examples return examples
def convert_examples_to_features(examples, tokenizer, max_seq_length,
is_training): def convert_examples_to_features(examples, tokenizer, max_seq_length, is_training):
"""Loads a data file into a list of `InputBatch`s.""" """Loads a data file into a list of `InputBatch`s."""
# Swag is a multiple choice task. To perform this task using Bert, # Swag is a multiple choice task. To perform this task using Bert,
...@@ -204,23 +193,18 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -204,23 +193,18 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
logger.info("swag_id: {}".format(example.swag_id)) logger.info("swag_id: {}".format(example.swag_id))
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features): for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
logger.info("choice: {}".format(choice_idx)) logger.info("choice: {}".format(choice_idx))
logger.info("tokens: {}".format(' '.join(tokens))) logger.info("tokens: {}".format(" ".join(tokens)))
logger.info("input_ids: {}".format(' '.join(map(str, input_ids)))) logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
logger.info("input_mask: {}".format(' '.join(map(str, input_mask)))) logger.info("input_mask: {}".format(" ".join(map(str, input_mask))))
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids)))) logger.info("segment_ids: {}".format(" ".join(map(str, segment_ids))))
if is_training: if is_training:
logger.info("label: {}".format(label)) logger.info("label: {}".format(label))
features.append( features.append(InputFeatures(example_id=example.swag_id, choices_features=choices_features, label=label))
InputFeatures(
example_id = example.swag_id,
choices_features = choices_features,
label = label
)
)
return features return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length.""" """Truncates a sequence pair in place to the maximum length."""
...@@ -237,18 +221,14 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): ...@@ -237,18 +221,14 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
else: else:
tokens_b.pop() tokens_b.pop()
def accuracy(out, labels): def accuracy(out, labels):
outputs = np.argmax(out, axis=1) outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels) return np.sum(outputs == labels)
def select_field(features, field): def select_field(features, field):
return [ return [[choice[field] for choice in feature.choices_features] for feature in features]
[
choice[field]
for choice in feature.choices_features
]
for feature in features
]
def set_seed(args): def set_seed(args):
...@@ -258,24 +238,28 @@ def set_seed(args): ...@@ -258,24 +238,28 @@ def set_seed(args):
if args.n_gpu > 0: if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed) torch.cuda.manual_seed_all(args.seed)
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0]: if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Load data features from cache or dataset file # Load data features from cache or dataset file
input_file = args.predict_file if evaluate else args.train_file input_file = args.predict_file if evaluate else args.train_file
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format( cached_features_file = os.path.join(
'dev' if evaluate else 'train', os.path.dirname(input_file),
list(filter(None, args.model_name_or_path.split('/'))).pop(), "cached_{}_{}_{}".format(
str(args.max_seq_length))) "dev" if evaluate else "train",
list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length),
),
)
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
logger.info("Loading features from cached file %s", cached_features_file) logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file) features = torch.load(cached_features_file)
else: else:
logger.info("Creating features from dataset file at %s", input_file) logger.info("Creating features from dataset file at %s", input_file)
examples = read_swag_examples(input_file) examples = read_swag_examples(input_file)
features = convert_examples_to_features( features = convert_examples_to_features(examples, tokenizer, args.max_seq_length, not evaluate)
examples, tokenizer, args.max_seq_length, not evaluate)
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file) logger.info("Saving features into cached file %s", cached_features_file)
...@@ -285,21 +269,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -285,21 +269,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Convert to Tensors and build dataset # Convert to Tensors and build dataset
all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long) all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long) all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long) all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
all_label = torch.tensor([f.label for f in features], dtype=torch.long) all_label = torch.tensor([f.label for f in features], dtype=torch.long)
if evaluate: if evaluate:
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
all_label)
else: else:
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
all_label)
if output_examples: if output_examples:
return dataset, examples, features return dataset, examples, features
return dataset return dataset
def train(args, train_dataset, model, tokenizer): def train(args, train_dataset, model, tokenizer):
""" Train the model """ """ Train the model """
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
...@@ -316,13 +300,18 @@ def train(args, train_dataset, model, tokenizer): ...@@ -316,13 +300,18 @@ def train(args, train_dataset, model, tokenizer):
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay) # Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight'] no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [ optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, {
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
] ]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
if args.fp16: if args.fp16:
try: try:
from apex import amp from apex import amp
...@@ -336,17 +325,21 @@ def train(args, train_dataset, model, tokenizer): ...@@ -336,17 +325,21 @@ def train(args, train_dataset, model, tokenizer):
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], model = torch.nn.parallel.DistributedDataParallel(
output_device=args.local_rank, model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
find_unused_parameters=True) )
# Train! # Train!
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", logger.info(
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) " Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total) logger.info(" Total optimization steps = %d", t_total)
...@@ -360,11 +353,13 @@ def train(args, train_dataset, model, tokenizer): ...@@ -360,11 +353,13 @@ def train(args, train_dataset, model, tokenizer):
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):
model.train() model.train()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0], inputs = {
'attention_mask': batch[1], "input_ids": batch[0],
#'token_type_ids': None if args.model_type == 'xlm' else batch[2], "attention_mask": batch[1],
'token_type_ids': batch[2], # 'token_type_ids': None if args.model_type == 'xlm' else batch[2],
'labels': batch[3]} "token_type_ids": batch[2],
"labels": batch[3],
}
# if args.model_type in ['xlnet', 'xlm']: # if args.model_type in ['xlnet', 'xlm']:
# inputs.update({'cls_index': batch[5], # inputs.update({'cls_index': batch[5],
# 'p_mask': batch[6]}) # 'p_mask': batch[6]})
...@@ -393,23 +388,27 @@ def train(args, train_dataset, model, tokenizer): ...@@ -393,23 +388,27 @@ def train(args, train_dataset, model, tokenizer):
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics # Log metrics
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well if (
args.local_rank == -1 and args.evaluate_during_training
): # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer) results = evaluate(args, model, tokenizer)
for key, value in results.items(): for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step) tb_writer.add_scalar("eval_{}".format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
logging_loss = tr_loss logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint # Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir) model_to_save.save_pretrained(output_dir)
tokenizer.save_vocabulary(output_dir) tokenizer.save_vocabulary(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin')) torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir) logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps: if args.max_steps > 0 and global_step > args.max_steps:
...@@ -424,6 +423,7 @@ def train(args, train_dataset, model, tokenizer): ...@@ -424,6 +423,7 @@ def train(args, train_dataset, model, tokenizer):
return global_step, tr_loss / global_step return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""): def evaluate(args, model, tokenizer, prefix=""):
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True) dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
...@@ -440,7 +440,6 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -440,7 +440,6 @@ def evaluate(args, model, tokenizer, prefix=""):
logger.info(" Num examples = %d", len(dataset)) logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size) logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss, eval_accuracy = 0, 0 eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0 nb_eval_steps, nb_eval_examples = 0, 0
...@@ -448,11 +447,13 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -448,11 +447,13 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval() model.eval()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad(): with torch.no_grad():
inputs = {'input_ids': batch[0], inputs = {
'attention_mask': batch[1], "input_ids": batch[0],
"attention_mask": batch[1],
# 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids # 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
'token_type_ids': batch[2], "token_type_ids": batch[2],
'labels': batch[3]} "labels": batch[3],
}
# if args.model_type in ['xlnet', 'xlm']: # if args.model_type in ['xlnet', 'xlm']:
# inputs.update({'cls_index': batch[4], # inputs.update({'cls_index': batch[4],
...@@ -462,17 +463,16 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -462,17 +463,16 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_loss += tmp_eval_loss.mean().item() eval_loss += tmp_eval_loss.mean().item()
logits = logits.detach().cpu().numpy() logits = logits.detach().cpu().numpy()
label_ids = inputs['labels'].to('cpu').numpy() label_ids = inputs["labels"].to("cpu").numpy()
tmp_eval_accuracy = accuracy(logits, label_ids) tmp_eval_accuracy = accuracy(logits, label_ids)
eval_accuracy += tmp_eval_accuracy eval_accuracy += tmp_eval_accuracy
nb_eval_steps += 1 nb_eval_steps += 1
nb_eval_examples += inputs['input_ids'].size(0) nb_eval_examples += inputs["input_ids"].size(0)
eval_loss = eval_loss / nb_eval_steps eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples eval_accuracy = eval_accuracy / nb_eval_examples
result = {'eval_loss': eval_loss, result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy}
'eval_accuracy': eval_accuracy}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt") output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer: with open(output_eval_file, "w") as writer:
...@@ -483,92 +483,144 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -483,92 +483,144 @@ def evaluate(args, model, tokenizer, prefix=""):
return result return result
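The evaluation loop above averages the loss over batches and the accuracy over examples. A minimal sketch of the kind of accuracy helper it relies on, assuming logits of shape (batch_size, num_choices) and integer label ids; the actual helper defined earlier in the script may differ in detail:

import numpy as np

def accuracy(logits, label_ids):
    # Count how many argmax predictions match the gold labels in this batch.
    predictions = np.argmax(logits, axis=1)
    return int((predictions == label_ids).sum())

Dividing the accumulated count by nb_eval_examples, as done above, then gives the fraction of correctly answered examples.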
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--train_file", default=None, type=str, required=True, parser.add_argument(
help="SWAG csv for training. E.g., train.csv") "--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
parser.add_argument("--predict_file", default=None, type=str, required=True, )
help="SWAG csv for predictions. E.g., val.csv or test.csv") parser.add_argument(
parser.add_argument("--model_type", default=None, type=str, required=True, "--predict_file",
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) default=None,
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, type=str,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) required=True,
parser.add_argument("--output_dir", default=None, type=str, required=True, help="SWAG csv for predictions. E.g., val.csv or test.csv",
help="The output directory where the model checkpoints and predictions will be written.") )
parser.add_argument(
## Other parameters "--model_type",
parser.add_argument("--config_name", default="", type=str, default=None,
help="Pretrained config name or path if not the same as model_name") type=str,
parser.add_argument("--tokenizer_name", default="", type=str, required=True,
help="Pretrained tokenizer name or path if not the same as model_name") help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
parser.add_argument("--max_seq_length", default=384, type=int, )
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints and predictions will be written.",
)
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)
parser.add_argument(
"--tokenizer_name",
default="",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--max_seq_length",
default=384,
type=int,
help="The maximum total input sequence length after tokenization. Sequences " help="The maximum total input sequence length after tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.") "longer than this will be truncated, and sequences shorter than this will be padded.",
parser.add_argument("--do_train", action='store_true', )
help="Whether to run training.") parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
help="Whether to run eval on the dev set.") parser.add_argument(
parser.add_argument("--evaluate_during_training", action='store_true', "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
help="Rul evaluation during training at each logging step.") )
parser.add_argument("--do_lower_case", action='store_true', parser.add_argument(
help="Set this flag if you are using an uncased model.") "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
)
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for training.") parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, parser.add_argument(
help="Batch size per GPU/CPU for evaluation.") "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
parser.add_argument("--learning_rate", default=5e-5, type=float, )
help="The initial learning rate for Adam.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, parser.add_argument(
help="Number of updates steps to accumulate before performing a backward/update pass.") "--gradient_accumulation_steps",
parser.add_argument("--weight_decay", default=0.0, type=float, type=int,
help="Weight deay if we apply some.") default=1,
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Number of updates steps to accumulate before performing a backward/update pass.",
help="Epsilon for Adam optimizer.") )
parser.add_argument("--max_grad_norm", default=1.0, type=float, parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
help="Max gradient norm.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--num_train_epochs", default=3.0, type=float, parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
help="Total number of training epochs to perform.") parser.add_argument(
parser.add_argument("--max_steps", default=-1, type=int, "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
help="If > 0: set total number of training steps to perform. Override num_train_epochs.") )
parser.add_argument("--warmup_steps", default=0, type=int, parser.add_argument(
help="Linear warmup over warmup_steps.") "--max_steps",
default=-1,
parser.add_argument('--logging_steps', type=int, default=50, type=int,
help="Log every X updates steps.") help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
parser.add_argument('--save_steps', type=int, default=50, )
help="Save checkpoint every X updates steps.") parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
parser.add_argument("--no_cuda", action='store_true', parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
help="Whether not to use CUDA when available") parser.add_argument(
parser.add_argument('--overwrite_output_dir', action='store_true', "--eval_all_checkpoints",
help="Overwrite the content of the output directory") action="store_true",
parser.add_argument('--overwrite_cache', action='store_true', help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number",
help="Overwrite the cached training and evaluation sets") )
parser.add_argument('--seed', type=int, default=42, parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
help="random seed for initialization") parser.add_argument(
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
parser.add_argument("--local_rank", type=int, default=-1, )
help="local_rank for distributed training on gpus") parser.add_argument(
parser.add_argument('--fp16', action='store_true', "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") )
parser.add_argument('--fp16_opt_level', type=str, default='O1', parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html") "See details at https://nvidia.github.io/apex/amp.html",
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") )
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: if (
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir
)
)
# Setup distant debugging if needed # Setup distant debugging if needed
if args.server_ip and args.server_port: if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
...@@ -580,16 +632,24 @@ def main(): ...@@ -580,16 +632,24 @@ def main():
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank) device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1 args.n_gpu = 1
args.device = device args.device = device
# Setup logging # Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) datefmt="%m/%d/%Y %H:%M:%S",
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) )
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
# Set seed # Set seed
set_seed(args) set_seed(args)
...@@ -601,8 +661,12 @@ def main(): ...@@ -601,8 +661,12 @@ def main():
args.model_type = args.model_type.lower() args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) tokenizer = tokenizer_class.from_pretrained(
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
)
model = model_class.from_pretrained(
args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
)
if args.local_rank == 0: if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
...@@ -617,7 +681,6 @@ def main(): ...@@ -617,7 +681,6 @@ def main():
global_step, tr_loss = train(args, train_dataset, model, tokenizer) global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Save the trained model and the tokenizer # Save the trained model and the tokenizer
if args.local_rank == -1 or torch.distributed.get_rank() == 0: if args.local_rank == -1 or torch.distributed.get_rank() == 0:
# Create output directory if needed # Create output directory if needed
...@@ -627,19 +690,20 @@ def main(): ...@@ -627,19 +690,20 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`. # Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir) model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir) model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir) tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device) model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = {} results = {}
if args.do_eval and args.local_rank in [-1, 0]: if args.do_eval and args.local_rank in [-1, 0]:
...@@ -650,14 +714,16 @@ def main(): ...@@ -650,14 +714,16 @@ def main():
checkpoints = [args.model_name_or_path] checkpoints = [args.model_name_or_path]
if args.eval_all_checkpoints: if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
# Reload the model # Reload the model
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
tokenizer = tokenizer_class.from_pretrained(checkpoint) tokenizer = tokenizer_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
...@@ -665,7 +731,7 @@ def main(): ...@@ -665,7 +731,7 @@ def main():
# Evaluate # Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items()) result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
results.update(result) results.update(result)
logger.info("Results: {}".format(results)) logger.info("Results: {}".format(results))
......
...@@ -23,51 +23,44 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -23,51 +23,44 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import argparse import argparse
import logging import logging
import time
import math import math
import time
import torch import torch
from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer from transformers import TransfoXLCorpus, TransfoXLLMHeadModel, TransfoXLTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
level = logging.INFO) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def main(): def main():
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103', parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
help='pretrained model name') parser.add_argument(
parser.add_argument('--split', type=str, default='test', "--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
choices=['all', 'valid', 'test'], )
help='which split to evaluate') parser.add_argument("--batch_size", type=int, default=10, help="batch size")
parser.add_argument('--batch_size', type=int, default=10, parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
help='batch size') parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
parser.add_argument('--tgt_len', type=int, default=128, parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
help='number of tokens to predict') parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
parser.add_argument('--ext_len', type=int, default=0, parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUDA is available")
help='length of the extended context') parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
parser.add_argument('--mem_len', type=int, default=1600, parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
help='length of the retained previous heads') parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
parser.add_argument('--clamp_len', type=int, default=1000, parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
help='max positional embedding index') parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument('--no_cuda', action='store_true',
help='Do not use CUDA even though CUA is available')
parser.add_argument('--work_dir', type=str, required=True,
help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
help='set same length attention with masking')
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative' assert args.ext_len >= 0, "extended context length must be non-negative"
if args.server_ip and args.server_port: if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
...@@ -84,17 +77,18 @@ def main(): ...@@ -84,17 +77,18 @@ def main():
corpus = TransfoXLCorpus.from_pretrained(args.model_name) corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab) ntokens = len(corpus.vocab)
va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
device=device, ext_len=args.ext_len) te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
# Load a pre-trained model # Load a pre-trained model
model = TransfoXLLMHeadModel.from_pretrained(args.model_name) model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
model = model.to(device) model = model.to(device)
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( logger.info(
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
)
)
model.reset_length(args.tgt_len, args.ext_len, args.mem_len) model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0: if args.clamp_len > 0:
...@@ -108,7 +102,7 @@ def main(): ...@@ -108,7 +102,7 @@ def main():
def evaluate(eval_iter): def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout. # Turn on evaluation mode which disables dropout.
model.eval() model.eval()
total_len, total_loss = 0, 0. total_len, total_loss = 0, 0.0
start_time = time.time() start_time = time.time()
with torch.no_grad(): with torch.no_grad():
mems = None mems = None
...@@ -119,35 +113,34 @@ def main(): ...@@ -119,35 +113,34 @@ def main():
total_loss += seq_len * loss.item() total_loss += seq_len * loss.item()
total_len += seq_len total_len += seq_len
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format( logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
total_time, 1000 * total_time / (idx+1)))
return total_loss / total_len return total_loss / total_len
# Run on test data. # Run on test data.
if args.split == 'all': if args.split == "all":
test_loss = evaluate(te_iter) test_loss = evaluate(te_iter)
valid_loss = evaluate(va_iter) valid_loss = evaluate(va_iter)
elif args.split == 'valid': elif args.split == "valid":
valid_loss = evaluate(va_iter) valid_loss = evaluate(va_iter)
test_loss = None test_loss = None
elif args.split == 'test': elif args.split == "test":
test_loss = evaluate(te_iter) test_loss = evaluate(te_iter)
valid_loss = None valid_loss = None
def format_log(loss, split): def format_log(loss, split):
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
split, loss, math.exp(loss))
return log_str return log_str
log_str = '' log_str = ""
if valid_loss is not None: if valid_loss is not None:
log_str += format_log(valid_loss, 'valid') log_str += format_log(valid_loss, "valid")
if test_loss is not None: if test_loss is not None:
log_str += format_log(test_loss, 'test') log_str += format_log(test_loss, "test")
logger.info('=' * 100) logger.info("=" * 100)
logger.info(log_str) logger.info(log_str)
logger.info('=' * 100) logger.info("=" * 100)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
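For reference, the ppl value printed by format_log is simply the exponential of the average per-token cross-entropy returned by evaluate. A minimal sketch of that relation with an illustrative loss value:

import math

avg_loss = 3.2  # hypothetical average cross-entropy per token from evaluate()
perplexity = math.exp(avg_loss)
print("| valid loss {:5.2f} | valid ppl {:9.3f}".format(avg_loss, perplexity))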
...@@ -15,39 +15,36 @@ ...@@ -15,39 +15,36 @@
""" The distiller to distil the student. """ The distiller to distil the student.
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
""" """
import os
import math import math
import psutil import os
import time import time
from tqdm import trange, tqdm
import numpy as np
import psutil
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.optim import AdamW from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, BatchSampler, DataLoader from tqdm import tqdm
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger
try: try:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
except: except ImportError:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from transformers import get_linear_schedule_with_warmup
from utils import logger
from lm_seqs_dataset import LmSeqsDataset
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
class Distiller: class Distiller:
def __init__(self, def __init__(
params: dict, self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
dataset: LmSeqsDataset, ):
token_probs: torch.tensor, logger.info("Initializing Distiller")
student: nn.Module,
teacher: nn.Module):
logger.info('Initializing Distiller')
self.params = params self.params = params
self.dump_path = params.dump_path self.dump_path = params.dump_path
self.multi_gpu = params.multi_gpu self.multi_gpu = params.multi_gpu
...@@ -70,12 +67,10 @@ class Distiller: ...@@ -70,12 +67,10 @@ class Distiller:
else: else:
sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False) sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)
self.dataloader = DataLoader(dataset=dataset, self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)
batch_sampler=sampler,
collate_fn=dataset.batch_sequences)
self.temperature = params.temperature self.temperature = params.temperature
assert self.temperature > 0. assert self.temperature > 0.0
self.alpha_ce = params.alpha_ce self.alpha_ce = params.alpha_ce
self.alpha_mlm = params.alpha_mlm self.alpha_mlm = params.alpha_mlm
...@@ -85,18 +80,18 @@ class Distiller: ...@@ -85,18 +80,18 @@ class Distiller:
self.mlm = params.mlm self.mlm = params.mlm
if self.mlm: if self.mlm:
logger.info(f'Using MLM loss for LM step.') logger.info(f"Using MLM loss for LM step.")
self.mlm_mask_prop = params.mlm_mask_prop self.mlm_mask_prop = params.mlm_mask_prop
assert 0.0 <= self.mlm_mask_prop <= 1.0 assert 0.0 <= self.mlm_mask_prop <= 1.0
assert params.word_mask + params.word_keep + params.word_rand == 1.0 assert params.word_mask + params.word_keep + params.word_rand == 1.0
self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand]) self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
if self.fp16: if self.fp16:
self.pred_probs = self.pred_probs.half() self.pred_probs = self.pred_probs.half()
self.token_probs = self.token_probs.half() self.token_probs = self.token_probs.half()
else: else:
logger.info(f'Using CLM loss for LM step.') logger.info(f"Using CLM loss for LM step.")
self.epoch = 0 self.epoch = 0
self.n_iter = 0 self.n_iter = 0
...@@ -107,38 +102,54 @@ class Distiller: ...@@ -107,38 +102,54 @@ class Distiller:
self.last_loss_ce = 0 self.last_loss_ce = 0
self.last_loss_mlm = 0 self.last_loss_mlm = 0
self.last_loss_clm = 0 self.last_loss_clm = 0
if self.alpha_mse > 0.: self.last_loss_mse = 0 if self.alpha_mse > 0.0:
if self.alpha_cos > 0.: self.last_loss_cos = 0 self.last_loss_mse = 0
if self.alpha_cos > 0.0:
self.last_loss_cos = 0
self.last_log = 0 self.last_log = 0
self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean') self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100) self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
if self.alpha_mse > 0.: if self.alpha_mse > 0.0:
self.mse_loss_fct = nn.MSELoss(reduction='sum') self.mse_loss_fct = nn.MSELoss(reduction="sum")
if self.alpha_cos > 0.: if self.alpha_cos > 0.0:
self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean') self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
logger.info('--- Initializing model optimizer') logger.info("--- Initializing model optimizer")
assert params.gradient_accumulation_steps >= 1 assert params.gradient_accumulation_steps >= 1
self.num_steps_epoch = len(self.dataloader) self.num_steps_epoch = len(self.dataloader)
num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1 num_train_optimization_steps = (
int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
)
no_decay = ['bias', 'LayerNorm.weight'] no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [ optimizer_grouped_parameters = [
{'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay}, {
{'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0} "params": [
p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": params.weight_decay,
},
{
"params": [
p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": 0.0,
},
] ]
logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad])) logger.info(
"------ Number of trainable parameters (student): %i"
% sum([p.numel() for p in self.student.parameters() if p.requires_grad])
)
logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()])) logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
self.optimizer = AdamW(optimizer_grouped_parameters, self.optimizer = AdamW(
lr=params.learning_rate, optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
eps=params.adam_epsilon, )
betas=(0.9, 0.98))
warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, self.scheduler = get_linear_schedule_with_warmup(
num_warmup_steps=warmup_steps, self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
num_training_steps=num_train_optimization_steps) )
if self.fp16: if self.fp16:
try: try:
...@@ -146,33 +157,36 @@ class Distiller: ...@@ -146,33 +157,36 @@ class Distiller:
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level") logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
self.student, self.optimizer = amp.initialize(self.student, self.student, self.optimizer = amp.initialize(
self.optimizer, self.student, self.optimizer, opt_level=self.params.fp16_opt_level
opt_level=self.params.fp16_opt_level) )
self.teacher = self.teacher.half() self.teacher = self.teacher.half()
if self.multi_gpu: if self.multi_gpu:
if self.fp16: if self.fp16:
from apex.parallel import DistributedDataParallel from apex.parallel import DistributedDataParallel
logger.info("Using apex.parallel.DistributedDataParallel for distributed training.") logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(self.student) self.student = DistributedDataParallel(self.student)
else: else:
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
logger.info("Using nn.parallel.DistributedDataParallel for distributed training.") logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
self.student = DistributedDataParallel(self.student, self.student = DistributedDataParallel(
self.student,
device_ids=[params.local_rank], device_ids=[params.local_rank],
output_device=params.local_rank, output_device=params.local_rank,
find_unused_parameters=True) find_unused_parameters=True,
)
self.is_master = params.is_master self.is_master = params.is_master
if self.is_master: if self.is_master:
logger.info('--- Initializing Tensorboard') logger.info("--- Initializing Tensorboard")
self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train')) self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
self.tensorboard.add_text(tag='config/training', text_string=str(self.params), global_step=0) self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
self.tensorboard.add_text(tag='config/student', text_string=str(self.student_config), global_step=0) self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)
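The optimizer setup in __init__ follows the usual recipe for transformer fine-tuning: biases and LayerNorm weights are excluded from weight decay, AdamW performs the updates, and a linear warmup schedule ramps the learning rate. A standalone sketch of the same pattern; the module and hyper-parameter values below are illustrative, not taken from the distillation config:

import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = nn.Linear(768, 768)  # stand-in for the student network
no_decay = ["bias", "LayerNorm.weight"]
grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(grouped_parameters, lr=5e-4, eps=1e-6, betas=(0.9, 0.98))
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)
# After each optimizer.step(), a scheduler.step() moves the learning rate along the warmup/decay ramp.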
def prepare_batch_mlm(self, def prepare_batch_mlm(self, batch):
batch):
""" """
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM. Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
...@@ -192,7 +206,7 @@ class Distiller: ...@@ -192,7 +206,7 @@ class Distiller:
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0) assert token_ids.size(0) == lengths.size(0)
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]) attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
bs, max_seq_len = token_ids.size() bs, max_seq_len = token_ids.size()
mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids) mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
...@@ -200,11 +214,13 @@ class Distiller: ...@@ -200,11 +214,13 @@ class Distiller:
x_prob = self.token_probs[token_ids.flatten()] x_prob = self.token_probs[token_ids.flatten()]
n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item()) n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False) tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility pred_mask = torch.zeros(
bs * max_seq_len, dtype=torch.bool, device=token_ids.device
) # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
pred_mask[tgt_ids] = 1 pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len) pred_mask = pred_mask.view(bs, max_seq_len)
pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0 pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0
# mask a number of words == 0 [8] (faster with fp16) # mask a number of words == 0 [8] (faster with fp16)
if self.fp16: if self.fp16:
...@@ -213,15 +229,19 @@ class Distiller: ...@@ -213,15 +229,19 @@ class Distiller:
pred_mask = pred_mask.view(-1) pred_mask = pred_mask.view(-1)
n2 = max(n1 % 8, 8 * (n1 // 8)) n2 = max(n1 % 8, 8 * (n1 // 8))
if n2 != n1: if n2 != n1:
pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0 pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
pred_mask = pred_mask.view(bs, max_seq_len) pred_mask = pred_mask.view(bs, max_seq_len)
assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item() assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()
_token_ids_real = token_ids[pred_mask] _token_ids_real = token_ids[pred_mask]
_token_ids_rand = _token_ids_real.clone().random_(self.vocab_size) _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
_token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token']) _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True) probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
_token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long() _token_ids = (
_token_ids_mask * (probs == 0).long()
+ _token_ids_real * (probs == 1).long()
+ _token_ids_rand * (probs == 2).long()
)
token_ids = token_ids.masked_scatter(pred_mask, _token_ids) token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility mlm_labels[~pred_mask] = -100 # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
...@@ -231,8 +251,7 @@ class Distiller: ...@@ -231,8 +251,7 @@ class Distiller:
return token_ids, attn_mask, mlm_labels return token_ids, attn_mask, mlm_labels
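prepare_batch_mlm first derives the attention mask from the sequence lengths, then applies BERT-style corruption to the positions selected for prediction: each chosen token is replaced by the mask token, kept unchanged, or replaced by a random token according to pred_probs. A self-contained sketch of that corruption step with illustrative ids and probabilities (the real special token ids come from params.special_tok_ids):

import torch

vocab_size, mask_token_id = 100, 3          # illustrative values
pred_probs = torch.tensor([0.8, 0.1, 0.1])  # P(mask), P(keep), P(random), i.e. word_mask / word_keep / word_rand

token_ids = torch.randint(5, vocab_size, (4, 16))  # fake batch of shape (bs, seq_len)
lengths = torch.tensor([16, 12, 9, 16])
attn_mask = torch.arange(token_ids.size(1)) < lengths[:, None]  # position i is valid iff i < sequence length

pred_mask = torch.rand(token_ids.shape) < 0.15     # positions selected for prediction
pred_mask[:, 0] = True                             # keep the sketch well-defined: at least one position per row
_real = token_ids[pred_mask]
_rand = _real.clone().random_(vocab_size)
_masked = _real.clone().fill_(mask_token_id)
choice = torch.multinomial(pred_probs, len(_real), replacement=True)
_new = _masked * (choice == 0).long() + _real * (choice == 1).long() + _rand * (choice == 2).long()
corrupted = token_ids.masked_scatter(pred_mask, _new)
mlm_labels = token_ids.clone()
mlm_labels[~pred_mask] = -100                      # only the selected positions contribute to the MLM loss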
def prepare_batch_clm(self, def prepare_batch_clm(self, batch):
batch):
""" """
Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM. Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
...@@ -252,7 +271,7 @@ class Distiller: ...@@ -252,7 +271,7 @@ class Distiller:
token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
assert token_ids.size(0) == lengths.size(0) assert token_ids.size(0) == lengths.size(0)
attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]) attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
clm_labels = token_ids.new(token_ids.size()).copy_(token_ids) clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility clm_labels[~attn_mask] = -100 # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
...@@ -261,9 +280,7 @@ class Distiller: ...@@ -261,9 +280,7 @@ class Distiller:
return token_ids, attn_mask, clm_labels return token_ids, attn_mask, clm_labels
def round_batch(self, def round_batch(self, x: torch.tensor, lengths: torch.tensor):
x: torch.tensor,
lengths: torch.tensor):
""" """
For float16 only. For float16 only.
Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8. Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
...@@ -299,9 +316,9 @@ class Distiller: ...@@ -299,9 +316,9 @@ class Distiller:
pad = 8 - (ml1 % 8) pad = 8 - (ml1 % 8)
ml2 = ml1 + pad ml2 = ml1 + pad
if self.mlm: if self.mlm:
pad_id = self.params.special_tok_ids['pad_token'] pad_id = self.params.special_tok_ids["pad_token"]
else: else:
pad_id = self.params.special_tok_ids['unk_token'] pad_id = self.params.special_tok_ids["unk_token"]
padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id) padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
x = torch.cat([x, padding_tensor], 1) x = torch.cat([x, padding_tensor], 1)
assert x.size() == (bs2, ml2) assert x.size() == (bs2, ml2)
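The branch above pads the sequence dimension up to the next multiple of 8 so that fp16 activations keep tensor-core-friendly shapes. A tiny sketch of the same rounding, independent of the class state (pad_id is an arbitrary illustrative value):

import torch

def pad_to_multiple_of_8(x, pad_id=0):
    # x: (batch, seq_len) token ids; right-pad with pad_id so that seq_len % 8 == 0.
    bs, seq_len = x.size()
    if seq_len % 8 == 0:
        return x
    pad = 8 - (seq_len % 8)
    padding = torch.full((bs, pad), pad_id, dtype=x.dtype, device=x.device)
    return torch.cat([x, padding], dim=1)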
...@@ -314,20 +331,22 @@ class Distiller: ...@@ -314,20 +331,22 @@ class Distiller:
""" """
The real training loop. The real training loop.
""" """
if self.is_master: logger.info('Starting training') if self.is_master:
logger.info("Starting training")
self.last_log = time.time() self.last_log = time.time()
self.student.train() self.student.train()
self.teacher.eval() self.teacher.eval()
for _ in range(self.params.n_epoch): for _ in range(self.params.n_epoch):
if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}') if self.is_master:
logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
if self.multi_gpu: if self.multi_gpu:
torch.distributed.barrier() torch.distributed.barrier()
iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0]) iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
for batch in iter_bar: for batch in iter_bar:
if self.params.n_gpu > 0: if self.params.n_gpu > 0:
batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch) batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)
if self.mlm: if self.mlm:
token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch) token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
...@@ -336,22 +355,21 @@ class Distiller: ...@@ -336,22 +355,21 @@ class Distiller:
self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels) self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)
iter_bar.update() iter_bar.update()
iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}', iter_bar.set_postfix(
'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'}) {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
)
iter_bar.close() iter_bar.close()
if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}') if self.is_master:
logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
self.end_epoch() self.end_epoch()
if self.is_master: if self.is_master:
logger.info(f'Save very last checkpoint as `pytorch_model.bin`.') logger.info(f"Save very last checkpoint as `pytorch_model.bin`.")
self.save_checkpoint(checkpoint_name=f'pytorch_model.bin') self.save_checkpoint(checkpoint_name=f"pytorch_model.bin")
logger.info('Training is finished') logger.info("Training is finished")
def step(self, def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
input_ids: torch.tensor,
attention_mask: torch.tensor,
lm_labels: torch.tensor):
""" """
One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation), One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
and possibly a parameter update (depending on the gradient accumulation). and possibly a parameter update (depending on the gradient accumulation).
...@@ -363,19 +381,27 @@ class Distiller: ...@@ -363,19 +381,27 @@ class Distiller:
lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM). lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
""" """
if self.mlm: if self.mlm:
s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size) s_logits, s_hidden_states = self.student(
input_ids=input_ids, attention_mask=attention_mask
) # (bs, seq_length, voc_size)
with torch.no_grad(): with torch.no_grad():
t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask) # (bs, seq_length, voc_size) t_logits, t_hidden_states = self.teacher(
input_ids=input_ids, attention_mask=attention_mask
) # (bs, seq_length, voc_size)
else: else:
s_logits, _, s_hidden_states = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) s_logits, _, s_hidden_states = self.student(
input_ids=input_ids, attention_mask=None
) # (bs, seq_length, voc_size)
with torch.no_grad(): with torch.no_grad():
t_logits, _, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) t_logits, _, t_hidden_states = self.teacher(
input_ids=input_ids, attention_mask=None
) # (bs, seq_length, voc_size)
assert s_logits.size() == t_logits.size() assert s_logits.size() == t_logits.size()
#https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
#https://github.com/peterliht/knowledge-distillation-pytorch/issues/2 # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
if self.params.restrict_ce_to_mask: if self.params.restrict_ce_to_mask:
mask = (lm_labels>-1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size) mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size)
else: else:
mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size) mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_length, voc_size)
s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask
...@@ -384,24 +410,30 @@ class Distiller: ...@@ -384,24 +410,30 @@ class Distiller:
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size() assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1), loss_ce = (
F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2 self.ce_loss_fct(
loss = self.alpha_ce*loss_ce F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
loss = self.alpha_ce * loss_ce
if self.alpha_mlm > 0.: if self.alpha_mlm > 0.0:
loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1)) loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
loss += self.alpha_mlm * loss_mlm loss += self.alpha_mlm * loss_mlm
if self.alpha_clm > 0.: if self.alpha_clm > 0.0:
shift_logits = s_logits[..., :-1, :].contiguous() shift_logits = s_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous() shift_labels = lm_labels[..., 1:].contiguous()
loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
shift_labels.view(-1))
loss += self.alpha_clm * loss_clm loss += self.alpha_clm * loss_clm
if self.alpha_mse > 0.: if self.alpha_mse > 0.0:
loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
0
) # Reproducing batchmean reduction
loss += self.alpha_mse * loss_mse loss += self.alpha_mse * loss_mse
if self.alpha_cos > 0.: if self.alpha_cos > 0.0:
s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim) s_hidden_states = s_hidden_states[-1] # (bs, seq_length, dim)
t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim) t_hidden_states = t_hidden_states[-1] # (bs, seq_length, dim)
mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim) mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states) # (bs, seq_length, dim)
...@@ -420,21 +452,20 @@ class Distiller: ...@@ -420,21 +452,20 @@ class Distiller:
self.total_loss_epoch += loss.item() self.total_loss_epoch += loss.item()
self.last_loss = loss.item() self.last_loss = loss.item()
self.last_loss_ce = loss_ce.item() self.last_loss_ce = loss_ce.item()
if self.alpha_mlm > 0.: if self.alpha_mlm > 0.0:
self.last_loss_mlm = loss_mlm.item() self.last_loss_mlm = loss_mlm.item()
if self.alpha_clm > 0.: if self.alpha_clm > 0.0:
self.last_loss_clm = loss_clm.item() self.last_loss_clm = loss_clm.item()
if self.alpha_mse > 0.: if self.alpha_mse > 0.0:
self.last_loss_mse = loss_mse.item() self.last_loss_mse = loss_mse.item()
if self.alpha_cos > 0.: if self.alpha_cos > 0.0:
self.last_loss_cos = loss_cos.item() self.last_loss_cos = loss_cos.item()
self.optimize(loss) self.optimize(loss)
self.n_sequences_epoch += input_ids.size(0) self.n_sequences_epoch += input_ids.size(0)
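The loss assembled in step() is a weighted sum whose core term is a temperature-scaled KL divergence between the student and teacher distributions, multiplied by T**2 so that gradient magnitudes stay comparable across temperatures. A minimal sketch of that core term in isolation; the temperature and alpha values are illustrative, not the training defaults:

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature, alpha_ce = 2.0, 5.0                    # illustrative hyper-parameters
ce_loss_fct = nn.KLDivLoss(reduction="batchmean")

s_logits = torch.randn(4, 100, requires_grad=True)  # student logits over a 100-token vocabulary
t_logits = torch.randn(4, 100)                      # teacher logits, no gradient needed

loss_ce = ce_loss_fct(
    F.log_softmax(s_logits / temperature, dim=-1),  # student log-probabilities
    F.softmax(t_logits / temperature, dim=-1),      # teacher probabilities (soft targets)
) * temperature ** 2
loss = alpha_ce * loss_ce
loss.backward()                                     # gradients flow only into the student logits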
def optimize(self, def optimize(self, loss):
loss):
""" """
Normalization on the loss (gradient accumulation or distributed training), followed by Normalization on the loss (gradient accumulation or distributed training), followed by
backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation). backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
...@@ -442,7 +473,7 @@ class Distiller: ...@@ -442,7 +473,7 @@ class Distiller:
""" """
# Check for NaN # Check for NaN
if (loss != loss).data.any(): if (loss != loss).data.any():
logger.error('NaN detected') logger.error("NaN detected")
exit() exit()
if self.multi_gpu: if self.multi_gpu:
...@@ -452,6 +483,7 @@ class Distiller: ...@@ -452,6 +483,7 @@ class Distiller:
if self.fp16: if self.fp16:
from apex import amp from apex import amp
with amp.scale_loss(loss, self.optimizer) as scaled_loss: with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
else: else:
...@@ -488,53 +520,84 @@ class Distiller: ...@@ -488,53 +520,84 @@ class Distiller:
return return
for param_name, param in self.student.named_parameters(): for param_name, param in self.student.named_parameters():
self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter) self.tensorboard.add_scalar(
self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter) tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
)
if param.grad is None: if param.grad is None:
continue continue
self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter) self.tensorboard.add_scalar(
self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter) tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter) self.tensorboard.add_scalar(
tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="losses/cum_avg_loss_epoch",
scalar_value=self.total_loss_epoch / self.n_iter,
global_step=self.n_total_iter,
)
self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter) self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter) self.tensorboard.add_scalar(
if self.alpha_mlm > 0.: tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter) )
if self.alpha_clm > 0.: if self.alpha_mlm > 0.0:
self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter) self.tensorboard.add_scalar(
if self.alpha_mse > 0.: tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter) )
if self.alpha_cos > 0.: if self.alpha_clm > 0.0:
self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter) self.tensorboard.add_scalar(
self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter) tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
)
self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory()._asdict()['used']/1_000_000, global_step=self.n_total_iter) if self.alpha_mse > 0.0:
self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=self.n_total_iter) self.tensorboard.add_scalar(
tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
)
if self.alpha_cos > 0.0:
self.tensorboard.add_scalar(
tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
)
self.tensorboard.add_scalar(
tag="global/memory_usage",
scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
global_step=self.n_total_iter,
)
self.tensorboard.add_scalar(
tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
)
def end_epoch(self): def end_epoch(self):
""" """
Finally arrived at the end of an epoch (full pass on the dataset). Finally arrived at the end of an epoch (full pass on the dataset).
Do some tensorboard logging and checkpoint saving. Do some tensorboard logging and checkpoint saving.
""" """
logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.') logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")
if self.is_master: if self.is_master:
self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth') self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch) self.tensorboard.add_scalar(
tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
)
self.epoch += 1 self.epoch += 1
self.n_sequences_epoch = 0 self.n_sequences_epoch = 0
self.n_iter = 0 self.n_iter = 0
self.total_loss_epoch = 0 self.total_loss_epoch = 0
def save_checkpoint(self, def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
checkpoint_name: str = 'checkpoint.pth'):
""" """
Save the current state. Only by the master process. Save the current state. Only by the master process.
""" """
if not self.is_master: if not self.is_master:
return return
mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
mdl_to_save.config.save_pretrained(self.dump_path) mdl_to_save.config.save_pretrained(self.dump_path)
state_dict = mdl_to_save.state_dict() state_dict = mdl_to_save.state_dict()
torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name)) torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
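For reference, the checkpoint written above is only a config saved via save_pretrained plus a raw state dict. A minimal reload sketch, assuming a DistilBERT masked-LM student and an illustrative dump directory (none of these names come from this diff), could look like:

import os
import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

dump_path = "serialization_dir"                        # hypothetical dump directory
config = DistilBertConfig.from_pretrained(dump_path)   # config.json written by save_pretrained above
student = DistilBertForMaskedLM(config)                # assumes an MLM student; adjust to your setup
state_dict = torch.load(os.path.join(dump_path, "model_epoch_0.pth"), map_location="cpu")
student.load_state_dict(state_dict)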
...@@ -17,18 +17,20 @@ ...@@ -17,18 +17,20 @@
import bisect import bisect
import copy import copy
from collections import defaultdict from collections import defaultdict
import numpy as np
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler from torch.utils.data.sampler import BatchSampler, Sampler
from utils import logger from utils import logger
def _quantize(x, bins): def _quantize(x, bins):
bins = copy.deepcopy(bins) bins = copy.deepcopy(bins)
bins = sorted(bins) bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized return quantized
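A quick worked example of the bucketing helper above (the values are illustrative): each length is mapped, via bisect, to the index of the bin it falls into.

import bisect

bins = [10]                      # single edge, as used by create_lengths_groups with the default k=0
lengths = [3, 15, 8, 27]
assert [bisect.bisect_right(bins, y) for y in lengths] == [0, 1, 0, 1]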
def create_lengths_groups(lengths, k=0): def create_lengths_groups(lengths, k=0):
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10] bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
groups = _quantize(lengths, bins) groups = _quantize(lengths, bins)
...@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0): ...@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0):
logger.info("Count of instances per bin: {}".format(counts)) logger.info("Count of instances per bin: {}".format(counts))
return groups return groups
class GroupedBatchSampler(BatchSampler): class GroupedBatchSampler(BatchSampler):
""" """
Wraps another sampler to yield a mini-batch of indices. Wraps another sampler to yield a mini-batch of indices.
...@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler): ...@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler):
0, i.e. they must be in the range [0, num_groups). 0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch. batch_size (int): Size of mini-batch.
""" """
def __init__(self, sampler, group_ids, batch_size): def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler): if not isinstance(sampler, Sampler):
raise ValueError( raise ValueError(
"sampler should be an instance of " "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
) )
self.sampler = sampler self.sampler = sampler
self.group_ids = group_ids self.group_ids = group_ids
...@@ -73,7 +76,7 @@ class GroupedBatchSampler(BatchSampler): ...@@ -73,7 +76,7 @@ class GroupedBatchSampler(BatchSampler):
buffer_per_group[group_id].append(idx) buffer_per_group[group_id].append(idx)
samples_per_group[group_id].append(idx) samples_per_group[group_id].append(idx)
if len(buffer_per_group[group_id]) == self.batch_size: if len(buffer_per_group[group_id]) == self.batch_size:
yield buffer_per_group[group_id] #TODO yield buffer_per_group[group_id] # TODO
num_batches += 1 num_batches += 1
del buffer_per_group[group_id] del buffer_per_group[group_id]
assert len(buffer_per_group[group_id]) < self.batch_size assert len(buffer_per_group[group_id]) < self.batch_size
...@@ -90,8 +93,8 @@ class GroupedBatchSampler(BatchSampler): ...@@ -90,8 +93,8 @@ class GroupedBatchSampler(BatchSampler):
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]): for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
batch_idx.extend(idxs) batch_idx.extend(idxs)
if len(batch_idx) >= self.batch_size: if len(batch_idx) >= self.batch_size:
yield batch_idx[:self.batch_size] yield batch_idx[: self.batch_size]
batch_idx = batch_idx[self.batch_size:] batch_idx = batch_idx[self.batch_size :]
num_remaining -= 1 num_remaining -= 1
if len(batch_idx) > 0: if len(batch_idx) > 0:
yield batch_idx yield batch_idx
......
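For context, a rough sketch of how such a length-grouped sampler is typically plugged into a DataLoader (the dataset, size and batch-size variables below are placeholders, not values from this diff); sequences are first bucketed by length, then the wrapped sampler yields whole batches drawn from a single bucket:

from torch.utils.data import DataLoader, RandomSampler

# `dataset` is assumed to expose per-sequence lengths and a collate function,
# as the LmSeqsDataset below does.
groups = create_lengths_groups(lengths=dataset.lengths, k=max_model_input_size)
sampler = GroupedBatchSampler(sampler=RandomSampler(dataset), group_ids=groups, batch_size=32)
dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)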
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
""" Dataset to distilled models """ Dataset to distilled models
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
""" """
import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
import numpy as np
from utils import logger from utils import logger
class LmSeqsDataset(Dataset): class LmSeqsDataset(Dataset):
"""Custom Dataset wrapping language modeling sequences. """Custom Dataset wrapping language modeling sequences.
...@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset): ...@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset):
data: `List[np.array[int]] data: `List[np.array[int]]
""" """
def __init__(self, def __init__(self, params, data):
params,
data):
self.params = params self.params = params
self.token_ids = np.array(data) self.token_ids = np.array(data)
...@@ -65,17 +64,17 @@ class LmSeqsDataset(Dataset): ...@@ -65,17 +64,17 @@ class LmSeqsDataset(Dataset):
""" """
max_len = self.params.max_model_input_size max_len = self.params.max_model_input_size
indices = self.lengths > max_len indices = self.lengths > max_len
logger.info(f'Splitting {sum(indices)} too long sequences.') logger.info(f"Splitting {sum(indices)} too long sequences.")
def divide_chunks(l, n): def divide_chunks(l, n):
return [l[i:i + n] for i in range(0, len(l), n)] return [l[i : i + n] for i in range(0, len(l), n)]
new_tok_ids = [] new_tok_ids = []
new_lengths = [] new_lengths = []
if self.params.mlm: if self.params.mlm:
cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token'] cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
else: else:
cls_id, sep_id = self.params.special_tok_ids['bos_token'], self.params.special_tok_ids['eos_token'] cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
for seq_, len_ in zip(self.token_ids, self.lengths): for seq_, len_ in zip(self.token_ids, self.lengths):
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_ assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
...@@ -84,7 +83,7 @@ class LmSeqsDataset(Dataset): ...@@ -84,7 +83,7 @@ class LmSeqsDataset(Dataset):
new_lengths.append(len_) new_lengths.append(len_)
else: else:
sub_seqs = [] sub_seqs = []
for sub_s in divide_chunks(seq_, max_len-2): for sub_s in divide_chunks(seq_, max_len - 2):
if sub_s[0] != cls_id: if sub_s[0] != cls_id:
sub_s = np.insert(sub_s, 0, cls_id) sub_s = np.insert(sub_s, 0, cls_id)
if sub_s[-1] != sep_id: if sub_s[-1] != sep_id:
...@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset): ...@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset):
self.token_ids = self.token_ids[indices] self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices] self.lengths = self.lengths[indices]
new_size = len(self) new_size = len(self)
logger.info(f'Remove {init_size - new_size} too short (<=11 tokens) sequences.') logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
def print_statistics(self): def print_statistics(self):
""" """
...@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset): ...@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset):
""" """
if not self.params.is_master: if not self.params.is_master:
return return
logger.info(f'{len(self)} sequences') logger.info(f"{len(self)} sequences")
# data_len = sum(self.lengths) # data_len = sum(self.lengths)
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids)))) # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)') # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
...@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset): ...@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset):
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids]) # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)') # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
def batch_sequences(self, def batch_sequences(self, batch):
batch):
""" """
Do the padding and transform into torch.tensor. Do the padding and transform into torch.tensor.
""" """
...@@ -139,10 +137,10 @@ class LmSeqsDataset(Dataset): ...@@ -139,10 +137,10 @@ class LmSeqsDataset(Dataset):
# Pad token ids # Pad token ids
if self.params.mlm: if self.params.mlm:
pad_idx = self.params.special_tok_ids['pad_token'] pad_idx = self.params.special_tok_ids["pad_token"]
else: else:
pad_idx = self.params.special_tok_ids['unk_token'] pad_idx = self.params.special_tok_ids["unk_token"]
tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids] tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
assert len(tk_) == len(token_ids) assert len(tk_) == len(token_ids)
assert all(len(t) == max_seq_len_ for t in tk_) assert all(len(t) == max_seq_len_ for t in tk_)
......
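To make the padding step in batch_sequences concrete, a small worked example (token ids and pad index are made up): every sequence in a batch is right-padded with the pad token up to the length of the longest sequence in that batch.

pad_idx = 0
token_ids = [[101, 7592, 102], [101, 2088, 999, 102]]
max_seq_len_ = max(len(t) for t in token_ids)   # 4
tk_ = [t + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
assert tk_ == [[101, 7592, 102, 0], [101, 2088, 999, 102]]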
...@@ -18,57 +18,73 @@ ...@@ -18,57 +18,73 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import argparse import argparse
import glob
import logging import logging
import os import os
import random import random
import glob
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
try: from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.tensorboard import SummaryWriter from torch.utils.data.distributed import DistributedSampler
except:
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig, from transformers import (
BertForQuestionAnswering, BertTokenizer, WEIGHTS_NAME,
XLMConfig, XLMForQuestionAnswering, AdamW,
XLMTokenizer, XLNetConfig, BertConfig,
BertForQuestionAnswering,
BertTokenizer,
DistilBertConfig,
DistilBertForQuestionAnswering,
DistilBertTokenizer,
XLMConfig,
XLMForQuestionAnswering,
XLMTokenizer,
XLNetConfig,
XLNetForQuestionAnswering, XLNetForQuestionAnswering,
XLNetTokenizer, XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) get_linear_schedule_with_warmup,
)
from transformers import AdamW, get_linear_schedule_with_warmup
from ..utils_squad import (
from ..utils_squad import (read_squad_examples, convert_examples_to_features, RawResult,
RawResult, write_predictions, RawResultExtended,
RawResultExtended, write_predictions_extended) convert_examples_to_features,
read_squad_examples,
write_predictions,
write_predictions_extended,
)
# The following import is the official SQuAD evaluation script (2.0). # The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library # You can remove it from the dependencies if you are using this script outside of the library
# We've added it here for automated tests (see examples/test_examples.py file) # We've added it here for automated tests (see examples/test_examples.py file)
from ..utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad from ..utils_squad_evaluate import EVAL_OPTS
from ..utils_squad_evaluate import main as evaluate_on_squad
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ ALL_MODELS = sum(
for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
)
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
} }
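The MODEL_CLASSES table drives the rest of the script: --model_type selects a (config, model, tokenizer) triple which main() later instantiates with from_pretrained. A minimal sketch of that lookup (the checkpoint name is illustrative):

config_class, model_class, tokenizer_class = MODEL_CLASSES["distilbert"]
tokenizer = tokenizer_class.from_pretrained("distilbert-base-uncased")
model = model_class.from_pretrained("distilbert-base-uncased")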
def set_seed(args): def set_seed(args):
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -76,9 +92,11 @@ def set_seed(args): ...@@ -76,9 +92,11 @@ def set_seed(args):
if args.n_gpu > 0: if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed) torch.cuda.manual_seed_all(args.seed)
def to_list(tensor): def to_list(tensor):
return tensor.detach().cpu().tolist() return tensor.detach().cpu().tolist()
def train(args, train_dataset, model, tokenizer, teacher=None): def train(args, train_dataset, model, tokenizer, teacher=None):
""" Train the model """ """ Train the model """
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
...@@ -95,13 +113,18 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -95,13 +113,18 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay) # Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight'] no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [ optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, {
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
] ]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
if args.fp16: if args.fp16:
try: try:
from apex import amp from apex import amp
...@@ -115,17 +138,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -115,17 +138,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], model = torch.nn.parallel.DistributedDataParallel(
output_device=args.local_rank, model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
find_unused_parameters=True) )
# Train! # Train!
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", logger.info(
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) " Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total) logger.info(" Total optimization steps = %d", t_total)
...@@ -141,37 +168,44 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -141,37 +168,44 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
if teacher is not None: if teacher is not None:
teacher.eval() teacher.eval()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0], inputs = {
'attention_mask': batch[1], "input_ids": batch[0],
'start_positions': batch[3], "attention_mask": batch[1],
'end_positions': batch[4]} "start_positions": batch[3],
if args.model_type != 'distilbert': "end_positions": batch[4],
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] }
if args.model_type in ['xlnet', 'xlm']: if args.model_type != "distilbert":
inputs.update({'cls_index': batch[5], inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
'p_mask': batch[6]}) if args.model_type in ["xlnet", "xlm"]:
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
outputs = model(**inputs) outputs = model(**inputs)
loss, start_logits_stu, end_logits_stu = outputs loss, start_logits_stu, end_logits_stu = outputs
# Distillation loss # Distillation loss
if teacher is not None: if teacher is not None:
if 'token_type_ids' not in inputs: if "token_type_ids" not in inputs:
inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2] inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
with torch.no_grad(): with torch.no_grad():
start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'], start_logits_tea, end_logits_tea = teacher(
token_type_ids=inputs['token_type_ids'], input_ids=inputs["input_ids"],
attention_mask=inputs['attention_mask']) token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"],
)
assert start_logits_tea.size() == start_logits_stu.size() assert start_logits_tea.size() == start_logits_stu.size()
assert end_logits_tea.size() == end_logits_stu.size() assert end_logits_tea.size() == end_logits_stu.size()
loss_fct = nn.KLDivLoss(reduction='batchmean') loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1), loss_start = loss_fct(
F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2) F.log_softmax(start_logits_stu / args.temperature, dim=-1),
loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1), F.softmax(start_logits_tea / args.temperature, dim=-1),
F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2) ) * (args.temperature ** 2)
loss_ce = (loss_start + loss_end)/2. loss_end = loss_fct(
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
F.softmax(end_logits_tea / args.temperature, dim=-1),
) * (args.temperature ** 2)
loss_ce = (loss_start + loss_end) / 2.0
loss = args.alpha_ce*loss_ce + args.alpha_squad*loss loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
if args.n_gpu > 1: if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
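The distillation term above follows the usual soft-target recipe: student and teacher logits are softened by a temperature T, compared with a batch-mean KL divergence, and rescaled by T**2 so gradient magnitudes stay comparable across temperatures. A self-contained sketch with illustrative shapes and temperature (not part of this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 2.0                                   # illustrative temperature
start_logits_stu = torch.randn(4, 384)    # (batch_size, sequence_length)
start_logits_tea = torch.randn(4, 384)

kl = nn.KLDivLoss(reduction="batchmean")
loss_start = kl(
    F.log_softmax(start_logits_stu / T, dim=-1),
    F.softmax(start_logits_tea / T, dim=-1),
) * (T ** 2)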
...@@ -195,22 +229,26 @@ def train(args, train_dataset, model, tokenizer, teacher=None): ...@@ -195,22 +229,26 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics # Log metrics
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well if (
args.local_rank == -1 and args.evaluate_during_training
): # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer) results = evaluate(args, model, tokenizer)
for key, value in results.items(): for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step) tb_writer.add_scalar("eval_{}".format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
logging_loss = tr_loss logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint # Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir) model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin')) torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir) logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps: if args.max_steps > 0 and global_step > args.max_steps:
...@@ -246,32 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -246,32 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval() model.eval()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad(): with torch.no_grad():
inputs = {'input_ids': batch[0], inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
'attention_mask': batch[1] if args.model_type != "distilbert":
} inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2] # XLM doesn't use segment_ids
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM doesn't use segment_ids example_indices = batch[3]
example_indices = batch[3] example_indices = batch[3]
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ["xlnet", "xlm"]:
inputs.update({'cls_index': batch[4], inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
'p_mask': batch[5]})
outputs = model(**inputs) outputs = model(**inputs)
for i, example_index in enumerate(example_indices): for i, example_index in enumerate(example_indices):
eval_feature = features[example_index.item()] eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id) unique_id = int(eval_feature.unique_id)
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ["xlnet", "xlm"]:
# XLNet uses a more complex post-processing procedure # XLNet uses a more complex post-processing procedure
result = RawResultExtended(unique_id = unique_id, result = RawResultExtended(
start_top_log_probs = to_list(outputs[0][i]), unique_id=unique_id,
start_top_index = to_list(outputs[1][i]), start_top_log_probs=to_list(outputs[0][i]),
end_top_log_probs = to_list(outputs[2][i]), start_top_index=to_list(outputs[1][i]),
end_top_index = to_list(outputs[3][i]), end_top_log_probs=to_list(outputs[2][i]),
cls_logits = to_list(outputs[4][i])) end_top_index=to_list(outputs[3][i]),
cls_logits=to_list(outputs[4][i]),
)
else: else:
result = RawResult(unique_id = unique_id, result = RawResult(
start_logits = to_list(outputs[0][i]), unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
end_logits = to_list(outputs[1][i])) )
all_results.append(result) all_results.append(result)
# Compute predictions # Compute predictions
...@@ -282,23 +319,44 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -282,23 +319,44 @@ def evaluate(args, model, tokenizer, prefix=""):
else: else:
output_null_log_odds_file = None output_null_log_odds_file = None
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ["xlnet", "xlm"]:
# XLNet uses a more complex post-processing procedure # XLNet uses a more complex post-processing procedure
write_predictions_extended(examples, features, all_results, args.n_best_size, write_predictions_extended(
args.max_answer_length, output_prediction_file, examples,
output_nbest_file, output_null_log_odds_file, args.predict_file, features,
model.config.start_n_top, model.config.end_n_top, all_results,
args.version_2_with_negative, tokenizer, args.verbose_logging) args.n_best_size,
args.max_answer_length,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
args.predict_file,
model.config.start_n_top,
model.config.end_n_top,
args.version_2_with_negative,
tokenizer,
args.verbose_logging,
)
else: else:
write_predictions(examples, features, all_results, args.n_best_size, write_predictions(
args.max_answer_length, args.do_lower_case, output_prediction_file, examples,
output_nbest_file, output_null_log_odds_file, args.verbose_logging, features,
args.version_2_with_negative, args.null_score_diff_threshold) all_results,
args.n_best_size,
args.max_answer_length,
args.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
args.verbose_logging,
args.version_2_with_negative,
args.null_score_diff_threshold,
)
# Evaluate with the official SQuAD script # Evaluate with the official SQuAD script
evaluate_options = EVAL_OPTS(data_file=args.predict_file, evaluate_options = EVAL_OPTS(
pred_file=output_prediction_file, data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
na_prob_file=output_null_log_odds_file) )
results = evaluate_on_squad(evaluate_options) results = evaluate_on_squad(evaluate_options)
return results return results
...@@ -309,24 +367,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -309,24 +367,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
# Load data features from cache or dataset file # Load data features from cache or dataset file
input_file = args.predict_file if evaluate else args.train_file input_file = args.predict_file if evaluate else args.train_file
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format( cached_features_file = os.path.join(
'dev' if evaluate else 'train', os.path.dirname(input_file),
list(filter(None, args.model_name_or_path.split('/'))).pop(), "cached_{}_{}_{}".format(
str(args.max_seq_length))) "dev" if evaluate else "train",
list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length),
),
)
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
logger.info("Loading features from cached file %s", cached_features_file) logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file) features = torch.load(cached_features_file)
else: else:
logger.info("Creating features from dataset file at %s", input_file) logger.info("Creating features from dataset file at %s", input_file)
examples = read_squad_examples(input_file=input_file, examples = read_squad_examples(
is_training=not evaluate, input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
version_2_with_negative=args.version_2_with_negative) )
features = convert_examples_to_features(examples=examples, features = convert_examples_to_features(
examples=examples,
tokenizer=tokenizer, tokenizer=tokenizer,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride, doc_stride=args.doc_stride,
max_query_length=args.max_query_length, max_query_length=args.max_query_length,
is_training=not evaluate) is_training=not evaluate,
)
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file) logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file) torch.save(features, cached_features_file)
...@@ -342,14 +406,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -342,14 +406,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
if evaluate: if evaluate:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, dataset = TensorDataset(
all_example_index, all_cls_index, all_p_mask) all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
)
else: else:
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, dataset = TensorDataset(
all_start_positions, all_end_positions, all_input_ids,
all_cls_index, all_p_mask) all_input_mask,
all_segment_ids,
all_start_positions,
all_end_positions,
all_cls_index,
all_p_mask,
)
if output_examples: if output_examples:
return dataset, examples, features return dataset, examples, features
...@@ -359,122 +430,214 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -359,122 +430,214 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--train_file", default=None, type=str, required=True, parser.add_argument(
help="SQuAD json for training. E.g., train-v1.1.json") "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
parser.add_argument("--predict_file", default=None, type=str, required=True, )
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") parser.add_argument(
parser.add_argument("--model_type", default=None, type=str, required=True, "--predict_file",
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) default=None,
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, type=str,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) required=True,
parser.add_argument("--output_dir", default=None, type=str, required=True, help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
help="The output directory where the model checkpoints and predictions will be written.") )
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints and predictions will be written.",
)
# Distillation parameters (optional) # Distillation parameters (optional)
parser.add_argument('--teacher_type', default=None, type=str, parser.add_argument(
help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.") "--teacher_type",
parser.add_argument('--teacher_name_or_path', default=None, type=str, default=None,
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.") type=str,
parser.add_argument('--alpha_ce', default=0.5, type=float, help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
help="Distillation loss linear weight. Only for distillation.") )
parser.add_argument('--alpha_squad', default=0.5, type=float, parser.add_argument(
help="True SQuAD loss linear weight. Only for distillation.") "--teacher_name_or_path",
parser.add_argument('--temperature', default=2.0, type=float, default=None,
help="Distillation temperature. Only for distillation.") type=str,
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
## Other parameters )
parser.add_argument("--config_name", default="", type=str, parser.add_argument(
help="Pretrained config name or path if not the same as model_name") "--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
parser.add_argument("--tokenizer_name", default="", type=str, )
help="Pretrained tokenizer name or path if not the same as model_name") parser.add_argument(
parser.add_argument("--cache_dir", default="", type=str, "--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
help="Where do you want to store the pre-trained models downloaded from s3") )
parser.add_argument(
parser.add_argument('--version_2_with_negative', action='store_true', "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
help='If true, the SQuAD examples contain some that do not have an answer.') )
parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.") # Other parameters
parser.add_argument(
parser.add_argument("--max_seq_length", default=384, type=int, "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)
parser.add_argument(
"--tokenizer_name",
default="",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
parser.add_argument(
"--version_2_with_negative",
action="store_true",
help="If true, the SQuAD examples contain some that do not have an answer.",
)
parser.add_argument(
"--null_score_diff_threshold",
type=float,
default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.",
)
parser.add_argument(
"--max_seq_length",
default=384,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences " help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.") "longer than this will be truncated, and sequences shorter than this will be padded.",
parser.add_argument("--doc_stride", default=128, type=int, )
help="When splitting up a long document into chunks, how much stride to take between chunks.") parser.add_argument(
parser.add_argument("--max_query_length", default=64, type=int, "--doc_stride",
default=128,
type=int,
help="When splitting up a long document into chunks, how much stride to take between chunks.",
)
parser.add_argument(
"--max_query_length",
default=64,
type=int,
help="The maximum number of tokens for the question. Questions longer than this will " help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.") "be truncated to this length.",
parser.add_argument("--do_train", action='store_true', )
help="Whether to run training.") parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
help="Whether to run eval on the dev set.") parser.add_argument(
parser.add_argument("--evaluate_during_training", action='store_true', "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
help="Run evaluation during training at each logging step.") )
parser.add_argument("--do_lower_case", action='store_true', parser.add_argument(
help="Set this flag if you are using an uncased model.") "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
)
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for training.") parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, parser.add_argument(
help="Batch size per GPU/CPU for evaluation.") "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
parser.add_argument("--learning_rate", default=5e-5, type=float, )
help="The initial learning rate for Adam.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, parser.add_argument(
help="Number of update steps to accumulate before performing a backward/update pass.") "--gradient_accumulation_steps",
parser.add_argument("--weight_decay", default=0.0, type=float, type=int,
help="Weight decay if we apply some.") default=1,
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Number of update steps to accumulate before performing a backward/update pass.",
help="Epsilon for Adam optimizer.") )
parser.add_argument("--max_grad_norm", default=1.0, type=float, parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
help="Max gradient norm.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--num_train_epochs", default=3.0, type=float, parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
help="Total number of training epochs to perform.") parser.add_argument(
parser.add_argument("--max_steps", default=-1, type=int, "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
help="If > 0: set total number of training steps to perform. Override num_train_epochs.") )
parser.add_argument("--warmup_steps", default=0, type=int, parser.add_argument(
help="Linear warmup over warmup_steps.") "--max_steps",
parser.add_argument("--n_best_size", default=20, type=int, default=-1,
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.") type=int,
parser.add_argument("--max_answer_length", default=30, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument(
"--n_best_size",
default=20,
type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
)
parser.add_argument(
"--max_answer_length",
default=30,
type=int,
help="The maximum length of an answer that can be generated. This is needed because the start " help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.") "and end predictions are not conditioned on one another.",
parser.add_argument("--verbose_logging", action='store_true', )
parser.add_argument(
"--verbose_logging",
action="store_true",
help="If true, all of the warnings related to data processing will be printed. " help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.") "A number of warnings are expected for a normal SQuAD evaluation.",
)
parser.add_argument('--logging_steps', type=int, default=50,
help="Log every X updates steps.") parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=50, parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
help="Save checkpoint every X updates steps.") parser.add_argument(
parser.add_argument("--eval_all_checkpoints", action='store_true', "--eval_all_checkpoints",
help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number") action="store_true",
parser.add_argument("--no_cuda", action='store_true', help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
help="Whether not to use CUDA when available") )
parser.add_argument('--overwrite_output_dir', action='store_true', parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
help="Overwrite the content of the output directory") parser.add_argument(
parser.add_argument('--overwrite_cache', action='store_true', "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
help="Overwrite the cached training and evaluation sets") )
parser.add_argument('--seed', type=int, default=42, parser.add_argument(
help="random seed for initialization") "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)
parser.add_argument("--local_rank", type=int, default=-1, parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
help="local_rank for distributed training on gpus")
parser.add_argument('--fp16', action='store_true', parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") parser.add_argument(
parser.add_argument('--fp16_opt_level', type=str, default='O1', "--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html") "See details at https://nvidia.github.io/apex/amp.html",
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") )
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: if (
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir
)
)
# Setup distant debugging if needed # Setup distant debugging if needed
if args.server_ip and args.server_port: if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
...@@ -486,16 +649,24 @@ def main(): ...@@ -486,16 +649,24 @@ def main():
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank) device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1 args.n_gpu = 1
args.device = device args.device = device
# Setup logging # Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) datefmt="%m/%d/%Y %H:%M:%S",
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) )
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
# Set seed # Set seed
set_seed(args) set_seed(args)
...@@ -506,27 +677,34 @@ def main(): ...@@ -506,27 +677,34 @@ def main():
args.model_type = args.model_type.lower() args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, config = config_class.from_pretrained(
cache_dir=args.cache_dir if args.cache_dir else None) args.config_name if args.config_name else args.model_name_or_path,
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None,
)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case, do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None) cache_dir=args.cache_dir if args.cache_dir else None,
model = model_class.from_pretrained(args.model_name_or_path, )
from_tf=bool('.ckpt' in args.model_name_or_path), model = model_class.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config, config=config,
cache_dir=args.cache_dir if args.cache_dir else None) cache_dir=args.cache_dir if args.cache_dir else None,
)
if args.teacher_type is not None: if args.teacher_type is not None:
assert args.teacher_name_or_path is not None assert args.teacher_name_or_path is not None
assert args.alpha_ce > 0. assert args.alpha_ce > 0.0
assert args.alpha_ce + args.alpha_squad > 0. assert args.alpha_ce + args.alpha_squad > 0.0
assert args.teacher_type != 'distilbert', "We constrain teachers not to be of type DistilBERT." assert args.teacher_type != "distilbert", "We constrain teachers not to be of type DistilBERT."
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type] teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path, teacher_config = teacher_config_class.from_pretrained(
cache_dir=args.cache_dir if args.cache_dir else None) args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path, )
config=teacher_config, teacher = teacher_model_class.from_pretrained(
cache_dir=args.cache_dir if args.cache_dir else None) args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
)
teacher.to(args.device) teacher.to(args.device)
else: else:
teacher = None teacher = None
...@@ -544,7 +722,6 @@ def main(): ...@@ -544,7 +722,6 @@ def main():
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher) global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Save the trained model and the tokenizer # Save the trained model and the tokenizer
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed # Create output directory if needed
...@@ -554,41 +731,44 @@ def main(): ...@@ -554,41 +731,44 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`. # Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir) model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None) model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, tokenizer = tokenizer_class.from_pretrained(
do_lower_case=args.do_lower_case, args.output_dir, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None
cache_dir=args.cache_dir if args.cache_dir else None) )
model.to(args.device) model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = {} results = {}
if args.do_eval and args.local_rank in [-1, 0]: if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
# Reload the model # Reload the model
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None) model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
model.to(args.device) model.to(args.device)
# Evaluate # Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items()) result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
results.update(result) results.update(result)
logger.info("Results: {}".format(results)) logger.info("Results: {}".format(results))
......
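For orientation, a hypothetical set of arguments that satisfies the parser defined above (checkpoint names and file paths are placeholders, not values from this diff); all required flags are supplied and distillation is enabled via the teacher options:

args = parser.parse_args([
    "--model_type", "distilbert",
    "--model_name_or_path", "distilbert-base-uncased",
    "--teacher_type", "bert",
    "--teacher_name_or_path", "bert-base-uncased",   # assumed to be a BERT checkpoint already fine-tuned on SQuAD
    "--train_file", "train-v1.1.json",
    "--predict_file", "dev-v1.1.json",
    "--output_dir", "distilled_squad",
    "--do_train", "--do_eval", "--do_lower_case",
])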
...@@ -16,75 +16,75 @@ ...@@ -16,75 +16,75 @@
Preprocessing script before distillation. Preprocessing script before distillation.
""" """
import argparse import argparse
import logging
import pickle import pickle
import random import random
import time import time
import numpy as np import numpy as np
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def main(): def main():
parser = argparse.ArgumentParser(description="Preprocess the data (tokenization + token_to_ids) once, to avoid re-doing it several times.") parser = argparse.ArgumentParser(
parser.add_argument('--file_path', type=str, default='data/dump.txt', description="Preprocess the data (tokenization + token_to_ids) once, to avoid re-doing it several times."
help='The path to the data.') )
parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta', 'gpt2']) parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased', parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
help="The tokenizer to use.") parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
parser.add_argument('--dump_file', type=str, default='data/dump', parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
help='The dump file prefix.')
args = parser.parse_args() args = parser.parse_args()
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
logger.info(f'Loading Tokenizer ({args.tokenizer_name})') if args.tokenizer_type == "bert":
if args.tokenizer_type == 'bert':
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name) tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['cls_token'] # `[CLS]` bos = tokenizer.special_tokens_map["cls_token"] # `[CLS]`
sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` sep = tokenizer.special_tokens_map["sep_token"] # `[SEP]`
elif args.tokenizer_type == 'roberta': elif args.tokenizer_type == "roberta":
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['cls_token'] # `<s>` bos = tokenizer.special_tokens_map["cls_token"] # `<s>`
sep = tokenizer.special_tokens_map['sep_token'] # `</s>` sep = tokenizer.special_tokens_map["sep_token"] # `</s>`
elif args.tokenizer_type == 'gpt2': elif args.tokenizer_type == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name) tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map['bos_token'] # `<|endoftext|>` bos = tokenizer.special_tokens_map["bos_token"] # `<|endoftext|>`
sep = tokenizer.special_tokens_map['eos_token'] # `<|endoftext|>` sep = tokenizer.special_tokens_map["eos_token"] # `<|endoftext|>`
logger.info(f'Loading text from {args.file_path}') logger.info(f"Loading text from {args.file_path}")
with open(args.file_path, 'r', encoding='utf8') as fp: with open(args.file_path, "r", encoding="utf8") as fp:
data = fp.readlines() data = fp.readlines()
logger.info(f"Start encoding")
logger.info(f'Start encoding') logger.info(f"{len(data)} examples to process.")
logger.info(f'{len(data)} examples to process.')
rslt = [] rslt = []
iter = 0 iter = 0
interval = 10000 interval = 10000
start = time.time() start = time.time()
for text in data: for text in data:
text = f'{bos} {text.strip()} {sep}' text = f"{bos} {text.strip()} {sep}"
token_ids = tokenizer.encode(text, add_special_tokens=False) token_ids = tokenizer.encode(text, add_special_tokens=False)
rslt.append(token_ids) rslt.append(token_ids)
iter += 1 iter += 1
if iter % interval == 0: if iter % interval == 0:
end = time.time() end = time.time()
logger.info(f'{iter} examples processed. - {(end-start)/interval:.2f}s/expl') logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
start = time.time() start = time.time()
logger.info('Finished binarization') logger.info("Finished binarization")
logger.info(f'{len(data)} examples processed.') logger.info(f"{len(data)} examples processed.")
dp_file = f'{args.dump_file}.{args.tokenizer_name}.pickle' dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
rslt_ = [np.uint16(d) for d in rslt] rslt_ = [np.uint16(d) for d in rslt]
random.shuffle(rslt_) random.shuffle(rslt_)
logger.info(f'Dump to {dp_file}') logger.info(f"Dump to {dp_file}")
with open(dp_file, 'wb') as handle: with open(dp_file, "wb") as handle:
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
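As an inspection sketch only (the path assumes the default --dump_file and --tokenizer_name above), the file written by this script is a pickled, shuffled list of uint16 token-id arrays, one per input line:
import pickle

with open("data/dump.bert-base-uncased.pickle", "rb") as fp:  # default dump prefix + tokenizer name
    sequences = pickle.load(fp)

print(len(sequences))       # number of binarized examples
print(sequences[0][:10])    # first ten token ids of one example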
......
...@@ -16,74 +16,87 @@ ...@@ -16,74 +16,87 @@
Preprocessing script before training the distilled model. Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2. Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
""" """
from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
import torch
import argparse import argparse
if __name__ == '__main__': import torch
parser = argparse.ArgumentParser(description="Extract some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation")
from transformers import GPT2LMHeadModel, RobertaForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"]) parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default='roberta-large', type=str) parser.add_argument("--model_name", default="roberta-large", type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str) parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.model_type == "roberta":
if args.model_type == 'roberta':
model = RobertaForMaskedLM.from_pretrained(args.model_name) model = RobertaForMaskedLM.from_pretrained(args.model_name)
prefix = 'roberta' prefix = "roberta"
elif args.model_type == 'gpt2': elif args.model_type == "gpt2":
model = GPT2LMHeadModel.from_pretrained(args.model_name) model = GPT2LMHeadModel.from_pretrained(args.model_name)
prefix = 'transformer' prefix = "transformer"
state_dict = model.state_dict() state_dict = model.state_dict()
compressed_sd = {} compressed_sd = {}
### Embeddings ### # Embeddings #
if args.model_type == 'gpt2': if args.model_type == "gpt2":
for param_name in ['wte.weight', 'wpe.weight']: for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}'] compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
else: else:
for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']: for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
param_name = f'{prefix}.embeddings.{w}.weight' param_name = f"{prefix}.embeddings.{w}.weight"
compressed_sd[param_name] = state_dict[param_name] compressed_sd[param_name] = state_dict[param_name]
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
param_name = f'{prefix}.embeddings.LayerNorm.{w}' param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name] compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ### # Transformer Blocks #
std_idx = 0 std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]: for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == 'gpt2': if args.model_type == "gpt2":
for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']: for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \ compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}'] f"{prefix}.h.{teacher_idx}.{layer}.{w}"
compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias'] ]
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
else: else:
for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value', for layer in [
'attention.output.dense', 'attention.output.LayerNorm', "attention.self.query",
'intermediate.dense', 'output.dense', 'output.LayerNorm']: "attention.self.key",
for w in ['weight', 'bias']: "attention.self.value",
compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \ "attention.output.dense",
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}'] "attention.output.LayerNorm",
"intermediate.dense",
"output.dense",
"output.LayerNorm",
]:
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
]
std_idx += 1 std_idx += 1
### Language Modeling Head ### # Language Modeling Head #
if args.model_type == 'roberta': if args.model_type == "roberta":
for layer in ['lm_head.decoder.weight', 'lm_head.bias']: for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f'{layer}'] = state_dict[f'{layer}'] compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
if args.vocab_transform: if args.vocab_transform:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}'] compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}'] compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
elif args.model_type == 'gpt2': elif args.model_type == "gpt2":
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}'] compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight'] compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
print(f'N layers selected for distillation: {std_idx}') print(f"N layers selected for distillation: {std_idx}")
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}') print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
print(f'Save transferred checkpoint to {args.dump_checkpoint}.') print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint) torch.save(compressed_sd, args.dump_checkpoint)
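To make the layer selection above concrete, here is a minimal sketch (assuming --model_type roberta, so the prefix is "roberta") of how the six teacher indices map onto student layers 0-5; the dumped checkpoint is then intended as a student initialization (e.g. via --student_pretrained_weights in the training script further below).
# Teacher layers [0, 2, 4, 7, 9, 11] initialize student layers [0, 1, 2, 3, 4, 5]:
teacher_layers = [0, 2, 4, 7, 9, 11]
for std_idx, teacher_idx in enumerate(teacher_layers):
    src = f"roberta.encoder.layer.{teacher_idx}.attention.self.query.weight"
    dst = f"roberta.encoder.layer.{std_idx}.attention.self.query.weight"
    print(f"{src} -> {dst}")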
...@@ -16,67 +16,77 @@ ...@@ -16,67 +16,77 @@
Preprocessing script before training DistilBERT. Preprocessing script before training DistilBERT.
Specific to BERT -> DistilBERT. Specific to BERT -> DistilBERT.
""" """
from transformers import BertForMaskedLM, RobertaForMaskedLM
import torch
import argparse import argparse
if __name__ == '__main__': import torch
parser = argparse.ArgumentParser(description="Extract some layers of the full BertForMaskedLM or RobertaForMaskedLM for Transfer Learned Distillation")
from transformers import BertForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="bert", choices=["bert"]) parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default='bert-base-uncased', type=str) parser.add_argument("--model_name", default="bert-base-uncased", type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str) parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.model_type == "bert":
if args.model_type == 'bert':
model = BertForMaskedLM.from_pretrained(args.model_name) model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = 'bert' prefix = "bert"
else: else:
raise ValueError(f'args.model_type should be "bert".') raise ValueError(f'args.model_type should be "bert".')
state_dict = model.state_dict() state_dict = model.state_dict()
compressed_sd = {} compressed_sd = {}
for w in ['word_embeddings', 'position_embeddings']: for w in ["word_embeddings", "position_embeddings"]:
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \ compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
state_dict[f'{prefix}.embeddings.{w}.weight'] for w in ["weight", "bias"]:
for w in ['weight', 'bias']: compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
state_dict[f'{prefix}.embeddings.LayerNorm.{w}']
std_idx = 0 std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]: for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}'] f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \ ]
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}'] compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \ f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}'] ]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
]
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}'] f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \ ]
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}'] compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
]
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \ compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}'] f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \ ]
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}'] compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \ f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}'] ]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
]
std_idx += 1 std_idx += 1
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight'] compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias'] compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
if args.vocab_transform: if args.vocab_transform:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}'] compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}'] compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
print(f'N layers selected for distillation: {std_idx}') print(f"N layers selected for distillation: {std_idx}")
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}') print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
print(f'Save transferred checkpoint to {args.dump_checkpoint}.') print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint) torch.save(compressed_sd, args.dump_checkpoint)
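A quick sanity-check sketch (the path is this script's default --dump_checkpoint): the dump is a plain PyTorch state dict whose keys already follow DistilBERT naming, so it can be inspected before being used as a student initialization.
import torch

compressed_sd = torch.load("serialization_dir/tf_bert-base-uncased_0247911.pth", map_location="cpu")
print(len(compressed_sd))                                                    # number of transferred tensors
print(all(k.startswith(("distilbert.", "vocab_")) for k in compressed_sd))   # True: student-style key names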
...@@ -15,37 +15,42 @@ ...@@ -15,37 +15,42 @@
""" """
Preprocessing script before training the distilled model. Preprocessing script before training the distilled model.
""" """
from collections import Counter
import argparse import argparse
import pickle
import logging import logging
import pickle
from collections import Counter
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
level = logging.INFO) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)") parser = argparse.ArgumentParser(
parser.add_argument("--data_file", type=str, default="data/dump.bert-base-uncased.pickle", description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
help="The binarized dataset.") )
parser.add_argument("--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", parser.add_argument(
help="The dump file.") "--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
)
parser.add_argument(
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
)
parser.add_argument("--vocab_size", default=30522, type=int) parser.add_argument("--vocab_size", default=30522, type=int)
args = parser.parse_args() args = parser.parse_args()
logger.info(f'Loading data from {args.data_file}') logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, 'rb') as fp: with open(args.data_file, "rb") as fp:
data = pickle.load(fp) data = pickle.load(fp)
logger.info('Counting occurrences for MLM.') logger.info("Counting occurrences for MLM.")
counter = Counter() counter = Counter()
for tk_ids in data: for tk_ids in data:
counter.update(tk_ids) counter.update(tk_ids)
counts = [0]*args.vocab_size counts = [0] * args.vocab_size
for k, v in counter.items(): for k, v in counter.items():
counts[k] = v counts[k] = v
logger.info(f'Dump to {args.token_counts_dump}') logger.info(f"Dump to {args.token_counts_dump}")
with open(args.token_counts_dump, 'wb') as handle: with open(args.token_counts_dump, "wb") as handle:
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
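A short worked example of how these counts are later used (the numbers are made up; the formula and the default smoothing of 0.7 come from the training script below): counts are raised to the power -mlm_smoothing, so rare tokens receive a much larger share of the masking probability mass.
import numpy as np

counts = np.array([10000, 100, 10, 0])         # hypothetical token frequencies
token_probs = np.maximum(counts, 1) ** -0.7    # same formula as in train.py with mlm_smoothing=0.7
print(token_probs)                             # approx. [0.0016, 0.0398, 0.1995, 1.0]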
...@@ -16,272 +16,304 @@ ...@@ -16,272 +16,304 @@
Training the distilled model. Training the distilled model.
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2. Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
""" """
import os
import argparse import argparse
import pickle
import json import json
import os
import pickle
import shutil import shutil
import numpy as np import numpy as np
import torch import torch
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from distiller import Distiller from distiller import Distiller
from utils import git_log, logger, init_gpu_params, set_seed
from lm_seqs_dataset import LmSeqsDataset from lm_seqs_dataset import LmSeqsDataset
from transformers import (
BertConfig,
BertForMaskedLM,
BertTokenizer,
DistilBertConfig,
DistilBertForMaskedLM,
DistilBertTokenizer,
GPT2Config,
GPT2LMHeadModel,
GPT2Tokenizer,
RobertaConfig,
RobertaForMaskedLM,
RobertaTokenizer,
)
from utils import git_log, init_gpu_params, logger, set_seed
MODEL_CLASSES = { MODEL_CLASSES = {
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
'bert': (BertConfig, BertForMaskedLM, BertTokenizer), "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer) "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
} }
def sanity_checks(args): def sanity_checks(args):
""" """
A bunch of args sanity checks to perform before even starting... A bunch of args sanity checks to perform before even starting...
""" """
assert (args.mlm and args.alpha_mlm > 0.) or (not args.mlm and args.alpha_mlm == 0.) assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
assert (args.alpha_mlm > 0. and args.alpha_clm == 0.) or (args.alpha_mlm == 0. and args.alpha_clm > 0.) assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
if args.mlm: if args.mlm:
assert os.path.isfile(args.token_counts) assert os.path.isfile(args.token_counts)
assert (args.student_type in ['roberta', 'distilbert']) and (args.teacher_type in ['roberta', 'bert']) assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
else: else:
assert (args.student_type in ['gpt2']) and (args.teacher_type in ['gpt2']) assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])
assert args.teacher_type == args.student_type or (args.student_type=='distilbert' and args.teacher_type=='bert') assert args.teacher_type == args.student_type or (
args.student_type == "distilbert" and args.teacher_type == "bert"
)
assert os.path.isfile(args.student_config) assert os.path.isfile(args.student_config)
if args.student_pretrained_weights is not None: if args.student_pretrained_weights is not None:
assert os.path.isfile(args.student_pretrained_weights) assert os.path.isfile(args.student_pretrained_weights)
if args.freeze_token_type_embds: assert args.student_type in ['roberta'] if args.freeze_token_type_embds:
assert args.student_type in ["roberta"]
assert args.alpha_ce >= 0.0
assert args.alpha_mlm >= 0.0
assert args.alpha_clm >= 0.0
assert args.alpha_mse >= 0.0
assert args.alpha_cos >= 0.0
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
assert args.alpha_ce >= 0.
assert args.alpha_mlm >= 0.
assert args.alpha_clm >= 0.
assert args.alpha_mse >= 0.
assert args.alpha_cos >= 0.
assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.
def freeze_pos_embeddings(student, args): def freeze_pos_embeddings(student, args):
if args.student_type == 'roberta': if args.student_type == "roberta":
student.roberta.embeddings.position_embeddings.weight.requires_grad = False student.roberta.embeddings.position_embeddings.weight.requires_grad = False
elif args.student_type == 'gpt2': elif args.student_type == "gpt2":
student.transformer.wpe.weight.requires_grad = False student.transformer.wpe.weight.requires_grad = False
def freeze_token_type_embeddings(student, args): def freeze_token_type_embeddings(student, args):
if args.student_type == 'roberta': if args.student_type == "roberta":
student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False
def main(): def main():
parser = argparse.ArgumentParser(description="Training") parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--force", action='store_true', parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")
help="Overwrite dump_path if it already exists.")
parser.add_argument(
parser.add_argument("--dump_path", type=str, required=True, "--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
help="The output directory (log, checkpoints, parameters, etc.)") )
parser.add_argument("--data_file", type=str, required=True, parser.add_argument(
help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.") "--data_file",
type=str,
parser.add_argument("--student_type", type=str, choices=["distilbert", "roberta", "gpt2"], required=True, required=True,
help="The student type (DistilBERT, RoBERTa).") help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
parser.add_argument("--student_config", type=str, required=True, )
help="Path to the student configuration.")
parser.add_argument("--student_pretrained_weights", default=None, type=str, parser.add_argument(
help="Load student initialization checkpoint.") "--student_type",
type=str,
parser.add_argument("--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, choices=["distilbert", "roberta", "gpt2"],
help="Teacher type (BERT, RoBERTa).") required=True,
parser.add_argument("--teacher_name", type=str, required=True, help="The student type (DistilBERT, RoBERTa).",
help="The teacher model.") )
parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
parser.add_argument("--temperature", default=2., type=float, parser.add_argument(
help="Temperature for the softmax temperature.") "--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
parser.add_argument("--alpha_ce", default=0.5, type=float, )
help="Linear weight for the distillation loss. Must be >=0.")
parser.add_argument("--alpha_mlm", default=0.0, type=float, parser.add_argument(
help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.") "--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
parser.add_argument("--alpha_clm", default=0.5, type=float, )
help="Linear weight for the CLM loss. Must be >=0.") parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")
parser.add_argument("--alpha_mse", default=0.0, type=float,
help="Linear weight of the MSE loss. Must be >=0.") parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax temperature.")
parser.add_argument("--alpha_cos", default=0.0, type=float, parser.add_argument(
help="Linear weight of the cosine embedding loss. Must be >=0.") "--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
)
parser.add_argument("--mlm", action="store_true", parser.add_argument(
help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM.") "--alpha_mlm",
parser.add_argument("--mlm_mask_prop", default=0.15, type=float, default=0.0,
help="Proportion of tokens for which we need to make a prediction.") type=float,
parser.add_argument("--word_mask", default=0.8, type=float, help="Linear weight for the MLM loss. Must be >=0. Should be used in coonjunction with `mlm` flag.",
help="Proportion of tokens to mask out.") )
parser.add_argument("--word_keep", default=0.1, type=float, parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
help="Proportion of tokens to keep.") parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
parser.add_argument("--word_rand", default=0.1, type=float, parser.add_argument(
help="Proportion of tokens to randomly replace.") "--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
parser.add_argument("--mlm_smoothing", default=0.7, type=float, )
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).")
parser.add_argument("--token_counts", type=str, parser.add_argument(
help="The token counts in the data_file for MLM.") "--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
)
parser.add_argument("--restrict_ce_to_mask", action='store_true', parser.add_argument(
help="If true, compute the distilation loss only the [MLM] prediction distribution.") "--mlm_mask_prop",
parser.add_argument("--freeze_pos_embs", action="store_true", default=0.15,
help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.") type=float,
parser.add_argument("--freeze_token_type_embds", action="store_true", help="Proportion of tokens for which we need to make a prediction.",
help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.") )
parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
parser.add_argument("--n_epoch", type=int, default=3, parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
help="Number of pass on the whole dataset.") parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
parser.add_argument("--batch_size", type=int, default=5, parser.add_argument(
help="Batch size (for each process).") "--mlm_smoothing",
parser.add_argument("--group_by_size", action='store_false', default=0.7,
help="If true, group sequences that have similar length into the same batch. Default is true.") type=float,
help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
parser.add_argument("--gradient_accumulation_steps", type=int, default=50, )
help="Gradient accumulation for larger training batches.") parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")
parser.add_argument("--warmup_prop", default=0.05, type=float,
help="Linear warmup proportion.") parser.add_argument(
parser.add_argument("--weight_decay", default=0.0, type=float, "--restrict_ce_to_mask",
help="Weight deay if we apply some.") action="store_true",
parser.add_argument("--learning_rate", default=5e-4, type=float, help="If true, compute the distilation loss only the [MLM] prediction distribution.",
help="The initial learning rate for Adam.") )
parser.add_argument("--adam_epsilon", default=1e-6, type=float, parser.add_argument(
help="Epsilon for Adam optimizer.") "--freeze_pos_embs",
parser.add_argument("--max_grad_norm", default=5.0, type=float, action="store_true",
help="Max gradient norm.") help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
parser.add_argument("--initializer_range", default=0.02, type=float, )
help="Random initialization range.") parser.add_argument(
"--freeze_token_type_embds",
parser.add_argument('--fp16', action='store_true', action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
parser.add_argument('--fp16_opt_level', type=str, default='O1', )
parser.add_argument("--n_epoch", type=int, default=3, help="Number of pass on the whole dataset.")
parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
parser.add_argument(
"--group_by_size",
action="store_false",
help="If true, group sequences that have similar length into the same batch. Default is true.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=50,
help="Gradient accumulation for larger training batches.",
)
parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html") "See details at https://nvidia.github.io/apex/amp.html",
parser.add_argument("--n_gpu", type=int, default=1, )
help="Number of GPUs in the node.") parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
parser.add_argument("--local_rank", type=int, default=-1, parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
help="Distributed training - Local rank") parser.add_argument("--seed", type=int, default=56, help="Random seed")
parser.add_argument("--seed", type=int, default=56,
help="Random seed") parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
parser.add_argument("--log_interval", type=int, default=500,
help="Tensorboard logging interval.")
parser.add_argument("--checkpoint_interval", type=int, default=4000,
help="Checkpoint interval.")
args = parser.parse_args() args = parser.parse_args()
sanity_checks(args) sanity_checks(args)
# ARGS #
## ARGS ##
init_gpu_params(args) init_gpu_params(args)
set_seed(args) set_seed(args)
if args.is_master: if args.is_master:
if os.path.exists(args.dump_path): if os.path.exists(args.dump_path):
if not args.force: if not args.force:
raise ValueError(f'Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite it. ' raise ValueError(
'Use `--force` if you want to overwrite it') f"Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite it. "
"Use `--force` if you want to overwrite it"
)
else: else:
shutil.rmtree(args.dump_path) shutil.rmtree(args.dump_path)
if not os.path.exists(args.dump_path): if not os.path.exists(args.dump_path):
os.makedirs(args.dump_path) os.makedirs(args.dump_path)
logger.info(f'Experiment will be dumped and logged in {args.dump_path}') logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
### SAVE PARAMS ### # SAVE PARAMS #
logger.info(f'Param: {args}') logger.info(f"Param: {args}")
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f: with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
json.dump(vars(args), f, indent=4) json.dump(vars(args), f, indent=4)
git_log(args.dump_path) git_log(args.dump_path)
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type] student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type] teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
### TOKENIZER ### # TOKENIZER #
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name) tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
special_tok_ids = {} special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items(): for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
idx = tokenizer.all_special_tokens.index(tok_symbol) idx = tokenizer.all_special_tokens.index(tok_symbol)
special_tok_ids[tok_name] = tokenizer.all_special_ids[idx] special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
logger.info(f'Special tokens {special_tok_ids}') logger.info(f"Special tokens {special_tok_ids}")
args.special_tok_ids = special_tok_ids args.special_tok_ids = special_tok_ids
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name] args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
# DATA LOADER #
## DATA LOADER ## logger.info(f"Loading data from {args.data_file}")
logger.info(f'Loading data from {args.data_file}') with open(args.data_file, "rb") as fp:
with open(args.data_file, 'rb') as fp:
data = pickle.load(fp) data = pickle.load(fp)
if args.mlm: if args.mlm:
logger.info(f'Loading token counts from {args.token_counts} (already pre-computed)') logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
with open(args.token_counts, 'rb') as fp: with open(args.token_counts, "rb") as fp:
counts = pickle.load(fp) counts = pickle.load(fp)
token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
for idx in special_tok_ids.values(): for idx in special_tok_ids.values():
token_probs[idx] = 0. # do not predict special tokens token_probs[idx] = 0.0 # do not predict special tokens
token_probs = torch.from_numpy(token_probs) token_probs = torch.from_numpy(token_probs)
else: else:
token_probs = None token_probs = None
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data) train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
logger.info(f'Data loader created.') logger.info(f"Data loader created.")
## STUDENT ## # STUDENT #
logger.info(f'Loading student config from {args.student_config}') logger.info(f"Loading student config from {args.student_config}")
stu_architecture_config = student_config_class.from_pretrained(args.student_config) stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True stu_architecture_config.output_hidden_states = True
if args.student_pretrained_weights is not None: if args.student_pretrained_weights is not None:
logger.info(f'Loading pretrained weights from {args.student_pretrained_weights}') logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
student = student_model_class.from_pretrained(args.student_pretrained_weights, student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
config=stu_architecture_config)
else: else:
student = student_model_class(stu_architecture_config) student = student_model_class(stu_architecture_config)
if args.n_gpu > 0: if args.n_gpu > 0:
student.to(f'cuda:{args.local_rank}') student.to(f"cuda:{args.local_rank}")
logger.info(f'Student loaded.') logger.info(f"Student loaded.")
## TEACHER ## # TEACHER #
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True) teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0: if args.n_gpu > 0:
teacher.to(f'cuda:{args.local_rank}') teacher.to(f"cuda:{args.local_rank}")
logger.info(f'Teacher loaded from {args.teacher_name}.') logger.info(f"Teacher loaded from {args.teacher_name}.")
## FREEZING ## # FREEZING #
if args.freeze_pos_embs: if args.freeze_pos_embs:
freeze_pos_embeddings(student, args) freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds: if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args) freeze_token_type_embeddings(student, args)
# SANITY CHECKS #
## SANITY CHECKS ##
assert student.config.vocab_size == teacher.config.vocab_size assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm: if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size assert token_probs.size(0) == stu_architecture_config.vocab_size
# DISTILLER #
## DISTILLER ##
torch.cuda.empty_cache() torch.cuda.empty_cache()
distiller = Distiller(params=args, distiller = Distiller(
dataset=train_lm_seq_dataset, params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
token_probs=token_probs, )
student=student,
teacher=teacher)
distiller.train() distiller.train()
logger.info("Let's go get some drinks.") logger.info("Let's go get some drinks.")
......
...@@ -15,17 +15,21 @@ ...@@ -15,17 +15,21 @@
""" Utils to train DistilBERT """ Utils to train DistilBERT
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
""" """
import git
import json import json
import logging
import os import os
import socket import socket
import torch
import git
import numpy as np import numpy as np
import torch
import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
level = logging.INFO) datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -35,12 +39,12 @@ def git_log(folder_path: str): ...@@ -35,12 +39,12 @@ def git_log(folder_path: str):
""" """
repo = git.Repo(search_parent_directories=True) repo = git.Repo(search_parent_directories=True)
repo_infos = { repo_infos = {
'repo_id': str(repo), "repo_id": str(repo),
'repo_sha': str(repo.head.object.hexsha), "repo_sha": str(repo.head.object.hexsha),
'repo_branch': str(repo.active_branch) "repo_branch": str(repo.active_branch),
} }
with open(os.path.join(folder_path, 'git_log.json'), 'w') as f: with open(os.path.join(folder_path, "git_log.json"), "w") as f:
json.dump(repo_infos, f, indent=4) json.dump(repo_infos, f, indent=4)
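A read-back sketch (the directory is a placeholder): git_log() above records the repo id, commit sha and branch of the code that produced a run, so an experiment directory can later be traced back to the exact code state.
import json

with open("serialization_dir/my_run/git_log.json") as f:  # placeholder dump_path
    repo_infos = json.load(f)

print(repo_infos["repo_sha"], repo_infos["repo_branch"])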
...@@ -57,21 +61,21 @@ def init_gpu_params(params): ...@@ -57,21 +61,21 @@ def init_gpu_params(params):
assert torch.cuda.is_available() assert torch.cuda.is_available()
logger.info('Initializing GPUs') logger.info("Initializing GPUs")
if params.n_gpu > 1: if params.n_gpu > 1:
assert params.local_rank != -1 assert params.local_rank != -1
params.world_size = int(os.environ['WORLD_SIZE']) params.world_size = int(os.environ["WORLD_SIZE"])
params.n_gpu_per_node = int(os.environ['N_GPU_NODE']) params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
params.global_rank = int(os.environ['RANK']) params.global_rank = int(os.environ["RANK"])
# number of nodes / node ID # number of nodes / node ID
params.n_nodes = params.world_size // params.n_gpu_per_node params.n_nodes = params.world_size // params.n_gpu_per_node
params.node_id = params.global_rank // params.n_gpu_per_node params.node_id = params.global_rank // params.n_gpu_per_node
params.multi_gpu = True params.multi_gpu = True
assert params.n_nodes == int(os.environ['N_NODES']) assert params.n_nodes == int(os.environ["N_NODES"])
assert params.node_id == int(os.environ['NODE_RANK']) assert params.node_id == int(os.environ["NODE_RANK"])
# local job (single GPU) # local job (single GPU)
else: else:
...@@ -114,8 +118,7 @@ def init_gpu_params(params): ...@@ -114,8 +118,7 @@ def init_gpu_params(params):
if params.multi_gpu: if params.multi_gpu:
logger.info("Initializing PyTorch distributed") logger.info("Initializing PyTorch distributed")
torch.distributed.init_process_group( torch.distributed.init_process_group(
init_method='env://', init_method="env://", backend="nccl",
backend='nccl',
) )
......
...@@ -19,50 +19,70 @@ from __future__ import absolute_import, division, print_function ...@@ -19,50 +19,70 @@ from __future__ import absolute_import, division, print_function
import argparse import argparse
import glob import glob
import json
import logging import logging
import os import os
import random import random
import json
from sklearn.metrics import f1_score
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import (
WEIGHTS_NAME,
AdamW,
AlbertConfig,
AlbertModel,
AlbertTokenizer,
BertConfig,
BertModel,
BertTokenizer,
DistilBertConfig,
DistilBertModel,
DistilBertTokenizer,
MMBTConfig,
MMBTForClassification,
RobertaConfig,
RobertaModel,
RobertaTokenizer,
XLMConfig,
XLMModel,
XLMTokenizer,
XLNetConfig,
XLNetModel,
XLNetTokenizer,
get_linear_schedule_with_warmup,
)
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
try: try:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
except: except ImportError:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms
from transformers import (WEIGHTS_NAME,
BertConfig, BertModel, BertTokenizer,
RobertaConfig, RobertaModel, RobertaTokenizer,
XLMConfig, XLMModel, XLMTokenizer,
XLNetConfig, XLNetModel, XLNetTokenizer,
DistilBertConfig, DistilBertModel, DistilBertTokenizer,
AlbertConfig, AlbertModel, AlbertTokenizer,
MMBTForClassification, MMBTConfig)
from transformers import AdamW, get_linear_schedule_with_warmup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, ALL_MODELS = sum(
RobertaConfig, DistilBertConfig)), ()) (
tuple(conf.pretrained_config_archive_map.keys())
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
),
(),
)
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': (BertConfig, BertModel, BertTokenizer), "bert": (BertConfig, BertModel, BertTokenizer),
'xlnet': (XLNetConfig, XLNetModel, XLNetTokenizer), "xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
'xlm': (XLMConfig, XLMModel, XLMTokenizer), "xlm": (XLMConfig, XLMModel, XLMTokenizer),
'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer), "roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertModel, DistilBertTokenizer), "distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
'albert': (AlbertConfig, AlbertModel, AlbertTokenizer) "albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
} }
...@@ -81,10 +101,13 @@ def train(args, train_dataset, model, tokenizer, criterion): ...@@ -81,10 +101,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size, batch_size=args.train_batch_size,
collate_fn=collate_fn, collate_fn=collate_fn,
num_workers=args.num_workers) num_workers=args.num_workers,
)
if args.max_steps > 0: if args.max_steps > 0:
t_total = args.max_steps t_total = args.max_steps
...@@ -93,14 +116,19 @@ def train(args, train_dataset, model, tokenizer, criterion): ...@@ -93,14 +116,19 @@ def train(args, train_dataset, model, tokenizer, criterion):
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay) # Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight'] no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [ optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, {
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
] ]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
if args.fp16: if args.fp16:
try: try:
from apex import amp from apex import amp
...@@ -114,17 +142,21 @@ def train(args, train_dataset, model, tokenizer, criterion): ...@@ -114,17 +142,21 @@ def train(args, train_dataset, model, tokenizer, criterion):
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], model = torch.nn.parallel.DistributedDataParallel(
output_device=args.local_rank, model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
find_unused_parameters=True) )
# Train! # Train!
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", logger.info(
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) " Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total) logger.info(" Total optimization steps = %d", t_total)
...@@ -140,11 +172,13 @@ def train(args, train_dataset, model, tokenizer, criterion): ...@@ -140,11 +172,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
model.train() model.train()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
labels = batch[5] labels = batch[5]
inputs = {'input_ids': batch[0], inputs = {
'input_modal': batch[2], "input_ids": batch[0],
'attention_mask': batch[1], "input_modal": batch[2],
'modal_start_tokens': batch[3], "attention_mask": batch[1],
'modal_end_tokens': batch[4]} "modal_start_tokens": batch[3],
"modal_end_tokens": batch[4],
}
outputs = model(**inputs) outputs = model(**inputs)
logits = outputs[0] # model outputs are always tuple in transformers (see doc) logits = outputs[0] # model outputs are always tuple in transformers (see doc)
loss = criterion(logits, labels) loss = criterion(logits, labels)
...@@ -174,30 +208,34 @@ def train(args, train_dataset, model, tokenizer, criterion): ...@@ -174,30 +208,34 @@ def train(args, train_dataset, model, tokenizer, criterion):
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
logs = {} logs = {}
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well if (
args.local_rank == -1 and args.evaluate_during_training
): # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer, criterion) results = evaluate(args, model, tokenizer, criterion)
for key, value in results.items(): for key, value in results.items():
eval_key = 'eval_{}'.format(key) eval_key = "eval_{}".format(key)
logs[eval_key] = value logs[eval_key] = value
loss_scalar = (tr_loss - logging_loss) / args.logging_steps loss_scalar = (tr_loss - logging_loss) / args.logging_steps
learning_rate_scalar = scheduler.get_lr()[0] learning_rate_scalar = scheduler.get_lr()[0]
logs['learning_rate'] = learning_rate_scalar logs["learning_rate"] = learning_rate_scalar
logs['loss'] = loss_scalar logs["loss"] = loss_scalar
logging_loss = tr_loss logging_loss = tr_loss
for key, value in logs.items(): for key, value in logs.items():
tb_writer.add_scalar(key, value, global_step) tb_writer.add_scalar(key, value, global_step)
print(json.dumps({**logs, **{'step': global_step}})) print(json.dumps({**logs, **{"step": global_step}}))
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint # Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME)) torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
torch.save(args, os.path.join(output_dir, 'training_args.bin')) torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir) logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps: if args.max_steps > 0 and global_step > args.max_steps:
...@@ -209,8 +247,8 @@ def train(args, train_dataset, model, tokenizer, criterion): ...@@ -209,8 +247,8 @@ def train(args, train_dataset, model, tokenizer, criterion):
if args.local_rank == -1: if args.local_rank == -1:
results = evaluate(args, model, tokenizer, criterion) results = evaluate(args, model, tokenizer, criterion)
if results['micro_f1'] > best_f1: if results["micro_f1"] > best_f1:
best_f1 = results['micro_f1'] best_f1 = results["micro_f1"]
n_no_improve = 0 n_no_improve = 0
else: else:
n_no_improve += 1 n_no_improve += 1
...@@ -236,7 +274,9 @@ def evaluate(args, model, tokenizer, criterion, prefix=""): ...@@ -236,7 +274,9 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly # Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn) eval_dataloader = DataLoader(
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
)
# multi-gpu eval # multi-gpu eval
if args.n_gpu > 1: if args.n_gpu > 1:
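A minimal sketch (not part of this diff) of the SequentialSampler plus custom collate_fn pattern used to build eval_dataloader above; the toy dataset and collate function are hypothetical:

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

eval_dataset = TensorDataset(torch.arange(10, dtype=torch.float).unsqueeze(1))

def collate_fn(batch):
    # Stack the per-sample tensors into a single batch tensor.
    return torch.stack([sample[0] for sample in batch])

eval_sampler = SequentialSampler(eval_dataset)  # deterministic order for evaluation
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=4, collate_fn=collate_fn)
for batch in eval_dataloader:
    print(batch.shape)  # torch.Size([4, 1]), [4, 1], then [2, 1]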
...@@ -257,11 +297,13 @@ def evaluate(args, model, tokenizer, criterion, prefix=""): ...@@ -257,11 +297,13 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
with torch.no_grad(): with torch.no_grad():
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
labels = batch[5] labels = batch[5]
inputs = {'input_ids': batch[0], inputs = {
'input_modal': batch[2], "input_ids": batch[0],
'attention_mask': batch[1], "input_modal": batch[2],
'modal_start_tokens': batch[3], "attention_mask": batch[1],
'modal_end_tokens': batch[4]} "modal_start_tokens": batch[3],
"modal_end_tokens": batch[4],
}
outputs = model(**inputs) outputs = model(**inputs)
logits = outputs[0] # model outputs are always tuple in transformers (see doc) logits = outputs[0] # model outputs are always tuple in transformers (see doc)
tmp_eval_loss = criterion(logits, labels) tmp_eval_loss = criterion(logits, labels)
...@@ -278,7 +320,7 @@ def evaluate(args, model, tokenizer, criterion, prefix=""): ...@@ -278,7 +320,7 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
result = { result = {
"loss": eval_loss, "loss": eval_loss,
"macro_f1": f1_score(out_label_ids, preds, average="macro"), "macro_f1": f1_score(out_label_ids, preds, average="macro"),
"micro_f1": f1_score(out_label_ids, preds, average="micro") "micro_f1": f1_score(out_label_ids, preds, average="micro"),
} }
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
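A minimal sketch (not part of this diff) of how the macro and micro F1 reported above differ on multi-label predictions; the toy labels are hypothetical:

import numpy as np
from sklearn.metrics import f1_score

out_label_ids = np.array([[1, 0, 1], [0, 1, 0]])  # multi-hot gold labels
preds = np.array([[1, 0, 1], [1, 1, 0]])          # thresholded model predictions

print(f1_score(out_label_ids, preds, average="macro"))  # unweighted mean of per-class F1
print(f1_score(out_label_ids, preds, average="micro"))  # F1 over all pooled label decisions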
...@@ -302,95 +344,148 @@ def load_examples(args, tokenizer, evaluate=False): ...@@ -302,95 +344,148 @@ def load_examples(args, tokenizer, evaluate=False):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--data_dir", default=None, type=str, required=True, parser.add_argument(
help="The input data dir. Should contain the .jsonl files for MMIMDB.") "--data_dir",
parser.add_argument("--model_type", default=None, type=str, required=True, default=None,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) type=str,
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) help="The input data dir. Should contain the .jsonl files for MMIMDB.",
parser.add_argument("--output_dir", default=None, type=str, required=True, )
help="The output directory where the model predictions and checkpoints will be written.") parser.add_argument(
"--model_type",
## Other parameters default=None,
parser.add_argument("--config_name", default="", type=str, type=str,
help="Pretrained config name or path if not the same as model_name") required=True,
parser.add_argument("--tokenizer_name", default="", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
help="Pretrained tokenizer name or path if not the same as model_name") )
parser.add_argument("--cache_dir", default="", type=str, parser.add_argument(
help="Where do you want to store the pre-trained models downloaded from s3") "--model_name_or_path",
parser.add_argument("--max_seq_length", default=128, type=int, default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)
parser.add_argument(
"--tokenizer_name",
default="",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer " help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.") "than this will be truncated, sequences shorter will be padded.",
parser.add_argument("--num_image_embeds", default=1, type=int, )
help="Number of Image Embeddings from the Image Encoder") parser.add_argument(
parser.add_argument("--do_train", action='store_true', "--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
help="Whether to run training.") )
parser.add_argument("--do_eval", action='store_true', parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
help="Whether to run eval on the dev set.") parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument("--evaluate_during_training", action='store_true', parser.add_argument(
help="Rul evaluation during training at each logging step.") "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
parser.add_argument("--do_lower_case", action='store_true', )
help="Set this flag if you are using an uncased model.") parser.add_argument(
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, )
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
help="Batch size per GPU/CPU for evaluation.") parser.add_argument(
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
help="Number of updates steps to accumulate before performing a backward/update pass.") )
parser.add_argument("--learning_rate", default=5e-5, type=float, parser.add_argument(
help="The initial learning rate for Adam.") "--gradient_accumulation_steps",
parser.add_argument("--weight_decay", default=0.0, type=float, type=int,
help="Weight deay if we apply some.") default=1,
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Number of updates steps to accumulate before performing a backward/update pass.",
help="Epsilon for Adam optimizer.") )
parser.add_argument("--max_grad_norm", default=1.0, type=float, parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
help="Max gradient norm.") parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
parser.add_argument("--num_train_epochs", default=3.0, type=float, parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
help="Total number of training epochs to perform.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--patience", default=5, type=int, parser.add_argument(
help="Patience for Early Stopping.") "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
parser.add_argument("--max_steps", default=-1, type=int, )
help="If > 0: set total number of training steps to perform. Override num_train_epochs.") parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
parser.add_argument("--warmup_steps", default=0, type=int, parser.add_argument(
help="Linear warmup over warmup_steps.") "--max_steps",
default=-1,
parser.add_argument('--logging_steps', type=int, default=50, type=int,
help="Log every X updates steps.") help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
parser.add_argument('--save_steps', type=int, default=50, )
help="Save checkpoint every X updates steps.") parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
parser.add_argument("--no_cuda", action='store_true', parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
help="Avoid using CUDA when available") parser.add_argument(
parser.add_argument('--num_workers', type=int, default=8, "--eval_all_checkpoints",
help="number of worker threads for dataloading") action="store_true",
parser.add_argument('--overwrite_output_dir', action='store_true', help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
help="Overwrite the content of the output directory") )
parser.add_argument('--overwrite_cache', action='store_true', parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
help="Overwrite the cached training and evaluation sets") parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")
parser.add_argument('--seed', type=int, default=42, parser.add_argument(
help="random seed for initialization") "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
)
parser.add_argument('--fp16', action='store_true', parser.add_argument(
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
parser.add_argument('--fp16_opt_level', type=str, default='O1', )
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html") "See details at https://nvidia.github.io/apex/amp.html",
parser.add_argument("--local_rank", type=int, default=-1, )
help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
args = parser.parse_args() args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: if (
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir
)
)
# Setup distant debugging if needed # Setup distant debugging if needed
if args.server_ip and args.server_port: if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
...@@ -402,17 +497,25 @@ def main(): ...@@ -402,17 +497,25 @@ def main():
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank) device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1 args.n_gpu = 1
args.device = device args.device = device
# Setup logging # Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) datefmt="%m/%d/%Y %H:%M:%S",
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) )
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
# Set seed # Set seed
set_seed(args) set_seed(args)
...@@ -426,13 +529,17 @@ def main(): ...@@ -426,13 +529,17 @@ def main():
num_labels = len(labels) num_labels = len(labels)
args.model_type = args.model_type.lower() args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
transformer_config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) transformer_config = config_class.from_pretrained(
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, args.config_name if args.config_name else args.model_name_or_path
)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case, do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None) cache_dir=args.cache_dir if args.cache_dir else None,
transformer = model_class.from_pretrained(args.model_name_or_path, )
config=transformer_config, transformer = model_class.from_pretrained(
cache_dir=args.cache_dir if args.cache_dir else None) args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
)
img_encoder = ImageEncoder(args) img_encoder = ImageEncoder(args)
config = MMBTConfig(transformer_config, num_labels=num_labels) config = MMBTConfig(transformer_config, num_labels=num_labels)
model = MMBTForClassification(config, transformer, img_encoder) model = MMBTForClassification(config, transformer, img_encoder)
...@@ -449,12 +556,13 @@ def main(): ...@@ -449,12 +556,13 @@ def main():
train_dataset = load_examples(args, tokenizer, evaluate=False) train_dataset = load_examples(args, tokenizer, evaluate=False)
label_frequences = train_dataset.get_label_frequencies() label_frequences = train_dataset.get_label_frequencies()
label_frequences = [label_frequences[l] for l in labels] label_frequences = [label_frequences[l] for l in labels]
label_weights = (torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)) ** -1 label_weights = (
torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)
) ** -1
criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights) criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion) global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Saving best-practices: if you use default names for the model, you can reload it using from_pretrained() # Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed # Create output directory if needed
...@@ -464,12 +572,14 @@ def main(): ...@@ -464,12 +572,14 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`. # Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME)) torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = MMBTForClassification(config, transformer, img_encoder) model = MMBTForClassification(config, transformer, img_encoder)
...@@ -477,24 +587,25 @@ def main(): ...@@ -477,24 +587,25 @@ def main():
tokenizer = tokenizer_class.from_pretrained(args.output_dir) tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device) model.to(args.device)
# Evaluation # Evaluation
results = {} results = {}
if args.do_eval and args.local_rank in [-1, 0]: if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else "" prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
model = MMBTForClassification(config, transformer, img_encoder) model = MMBTForClassification(config, transformer, img_encoder)
model.load_state_dict(torch.load(checkpoint)) model.load_state_dict(torch.load(checkpoint))
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, criterion, prefix=prefix) result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
return results return results
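A minimal sketch (not part of this diff) of the checkpoint discovery performed in the evaluation branch above; the directory layout and weights filename are hypothetical:

import glob
import os

output_dir = "output-demo"          # hypothetical fine-tuning output directory
WEIGHTS_NAME = "pytorch_model.bin"  # filename assumed for the saved weights

# Find every sub-directory containing a weights file, in sorted order.
checkpoints = list(
    os.path.dirname(c) for c in sorted(glob.glob(output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
for checkpoint in checkpoints:
    global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
    prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
    print(checkpoint, global_step, prefix)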
......
...@@ -17,25 +17,16 @@ ...@@ -17,25 +17,16 @@
import json import json
import os import os
from collections import Counter from collections import Counter
from PIL import Image
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset
POOLING_BREAKDOWN = {
1: (1, 1), POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
2: (2, 1),
3: (3, 1),
4: (2, 2),
5: (5, 1),
6: (3, 2),
7: (7, 1),
8: (4, 2),
9: (3, 3)
}
class ImageEncoder(nn.Module): class ImageEncoder(nn.Module):
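A minimal sketch (not part of this diff) of how the POOLING_BREAKDOWN table can drive an adaptive pooling grid, assuming the encoder reduces a Bx2048xHxW ResNet feature map to the BxNx2048 output noted below; the tensor sizes are hypothetical:

import torch
import torch.nn as nn

POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2)}

num_image_embeds = 4
pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[num_image_embeds])

features = torch.randn(8, 2048, 7, 7)            # B x C x H x W feature map
out = pool(features).flatten(2).transpose(1, 2)  # B x N x 2048
print(out.shape)                                 # torch.Size([8, 4, 2048])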
...@@ -54,7 +45,6 @@ class ImageEncoder(nn.Module): ...@@ -54,7 +45,6 @@ class ImageEncoder(nn.Module):
return out # BxNx2048 return out # BxNx2048
class JsonlDataset(Dataset): class JsonlDataset(Dataset):
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length): def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
self.data = [json.loads(l) for l in open(data_path)] self.data = [json.loads(l) for l in open(data_path)]
...@@ -72,7 +62,7 @@ class JsonlDataset(Dataset): ...@@ -72,7 +62,7 @@ class JsonlDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True)) sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1] start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
sentence = sentence[:self.max_seq_length] sentence = sentence[: self.max_seq_length]
label = torch.zeros(self.n_classes) label = torch.zeros(self.n_classes)
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1 label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
...@@ -80,8 +70,13 @@ class JsonlDataset(Dataset): ...@@ -80,8 +70,13 @@ class JsonlDataset(Dataset):
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB") image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
image = self.transforms(image) image = self.transforms(image)
return {"image_start_token": start_token, "image_end_token": end_token, return {
"sentence": sentence, "image": image, "label": label} "image_start_token": start_token,
"image_end_token": end_token,
"sentence": sentence,
"image": image,
"label": label,
}
def get_label_frequencies(self): def get_label_frequencies(self):
label_freqs = Counter() label_freqs = Counter()
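A minimal sketch (not part of this diff) of the multi-hot label vector built in __getitem__ above; the label set and sample are hypothetical:

import torch

labels = ["Crime", "Drama", "Comedy"]
sample = {"label": ["Drama", "Comedy"]}

label = torch.zeros(len(labels))
label[[labels.index(tgt) for tgt in sample["label"]]] = 1
print(label)  # tensor([0., 1., 1.])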
...@@ -110,10 +105,31 @@ def collate_fn(batch): ...@@ -110,10 +105,31 @@ def collate_fn(batch):
def get_mmimdb_labels(): def get_mmimdb_labels():
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance', return [
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure', "Crime",
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music', "Drama",
'Musical', 'Animation', 'Biography', 'Film-Noir'] "Thriller",
"Action",
"Comedy",
"Romance",
"Documentary",
"Short",
"Mystery",
"History",
"Family",
"Adventure",
"Fantasy",
"Sci-Fi",
"Western",
"Horror",
"Sport",
"War",
"Music",
"Musical",
"Animation",
"Biography",
"Film-Noir",
]
def get_image_transforms(): def get_image_transforms():
...@@ -122,9 +138,6 @@ def get_image_transforms(): ...@@ -122,9 +138,6 @@ def get_image_transforms():
transforms.Resize(256), transforms.Resize(256),
transforms.CenterCrop(224), transforms.CenterCrop(224),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
mean=[0.46777044, 0.44531429, 0.40661017],
std=[0.12221994, 0.12145835, 0.14380469],
),
] ]
) )
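A minimal sketch (not part of this diff) of applying the get_image_transforms() pipeline above to a single poster image; the in-memory image is a hypothetical stand-in for a file from the dataset:

import torchvision.transforms as transforms
from PIL import Image

preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469]
        ),
    ]
)

image = Image.new("RGB", (640, 480))  # stand-in for Image.open(...).convert("RGB")
tensor = preprocess(image)            # 3 x 224 x 224 normalized float tensor
batch = tensor.unsqueeze(0)           # add a batch dimension before the encoder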
import torch import torch
class ClassificationHead(torch.nn.Module): class ClassificationHead(torch.nn.Module):
"""Classification Head for transformer encoders""" """Classification Head for transformer encoders"""
......