"...src/routers/git@developer.sourcefind.cn:change/sglang.git" did not exist on "36efd5be8a0f0c5e0f07dcd3a0b6b4df5d210c89"
Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas's preference, because it allows for explicit variable
names, which make the code easier to understand.
parent 63e3827c
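For readers who want to reproduce or pin this style, the same 119-character setting can also be applied through black's Python API. The snippet below is only an illustrative sketch, not part of this commit; it assumes black is installed, and the sample source string is made up:

    # Illustrative sketch (not part of this commit): format a source string with
    # the same 119-character line length that the command above uses.
    import black

    source = "def add(a,b):\n    return a+b\n"
    formatted = black.format_str(source, mode=black.FileMode(line_length=119))
    print(formatted)  # -> "def add(a, b):\n    return a + b\n"

The project-wide equivalent is usually a [tool.black] section with line-length = 119 in pyproject.toml.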
@@ -247,7 +247,8 @@ the wall, slowly on into the Social Predestination Room.
as they entered."""


def create_setup_and_compute(
    model_names: List[str],
    gpu: bool = True,
    tensorflow: bool = False,
    average_over: int = 3,
@@ -256,7 +257,8 @@ def create_setup_and_compute(model_names: List[str],
    amp: bool = False,
    fp16: bool = False,
    save_to_csv: bool = False,
    csv_filename: str = f"results_{round(time())}.csv",
):
    if xla:
        tf.config.optimizer.set_jit(True)
    if amp:
@@ -266,7 +268,7 @@ def create_setup_and_compute(model_names: List[str],
        dictionary = {model_name: {} for model_name in model_names}
        results = _compute_tensorflow(model_names, dictionary, average_over, amp)
    else:
        device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
        dictionary = {model_name: {} for model_name in model_names}
        results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
@@ -276,22 +278,40 @@ def create_setup_and_compute(model_names: List[str],
        for batch_size in results[model_name]["bs"]:
            print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
            for slice_size in results[model_name]["ss"]:
                result = results[model_name]["results"][batch_size][slice_size]
                if isinstance(result, str):
                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
                else:
                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")
    if save_to_csv:
        with open(csv_filename, mode="w") as csv_file:
            fieldnames = [
                "model",
                "1x8",
                "1x64",
                "1x128",
                "1x256",
                "1x512",
                "1x1024",
                "2x8",
                "2x64",
                "2x128",
                "2x256",
                "2x512",
                "2x1024",
                "4x8",
                "4x64",
                "4x128",
                "4x256",
                "4x512",
                "4x1024",
                "8x8",
                "8x64",
                "8x128",
                "8x256",
                "8x512",
                "8x1024",
            ]
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
@@ -299,11 +319,11 @@ def create_setup_and_compute(model_names: List[str],
            for model_name in model_names:
                model_results = {
                    f"{bs}x{ss}": results[model_name]["results"][bs][ss]
                    for bs in results[model_name]["results"]
                    for ss in results[model_name]["results"][bs]
                }
                writer.writerow({"model": model_name, **model_results})


def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
@@ -343,7 +363,7 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
                    print("Going through model with sequence of shape", sequence.shape)
                    runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
                    average_time = sum(runtimes) / float(len(runtimes)) / 3.0
                    dictionary[model_name]["results"][batch_size][slice_size] = average_time
                except RuntimeError as e:
                    print("Doesn't fit on GPU.", e)
@@ -379,7 +399,9 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                if max_input_size is not None and slice_size > max_input_size:
                    dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
                else:
                    sequence = tf.stack(
                        [tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size
                    )
                    try:
                        print("Going through model with sequence of shape", sequence.shape)
@@ -387,7 +409,7 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                        inference(sequence)
                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
                    except tf.errors.ResourceExhaustedError as e:
                        print("Doesn't fit on GPU.", e)
@@ -399,33 +421,64 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        required=False,
        type=str,
        default="all",
        help="Model checkpoints to be provided "
        "to the AutoModel classes. Leave "
        "blank to benchmark the base version "
        "of all available model "
        "architectures.",
    )
    parser.add_argument(
        "--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
    )
    parser.add_argument(
        "--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available " "cuda devices"
    )
    parser.add_argument(
        "--torchscript",
        required=False,
        action="store_true",
        help="Pytorch only: trace the models " "using torchscript",
    )
    parser.add_argument(
        "--tensorflow",
        required=False,
        action="store_true",
        help="Benchmark the TensorFlow version "
        "of the models. Will run on GPU if "
        "the correct dependencies are "
        "installed",
    )
    parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
    parser.add_argument(
        "--amp",
        required=False,
        action="store_true",
        help="TensorFlow only: use automatic mixed precision acceleration.",
    )
    parser.add_argument(
        "--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference."
    )
    parser.add_argument(
        "--keras_predict",
        required=False,
        action="store_true",
        help="Whether to use model.predict " "instead of model() to do a " "forward pass.",
    )
    parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
    parser.add_argument(
        "--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv."
    )
    parser.add_argument(
        "--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
    )
    args = parser.parse_args()
    if args.models == "all":
        args.models = [
            "gpt2",
            "bert-base-cased",
@@ -436,7 +489,7 @@ def main():
            "distilbert-base-uncased",
            "distilgpt2",
            "roberta-base",
            "ctrl",
        ]
    else:
        args.models = args.models.split()
@@ -453,7 +506,7 @@ def main():
                fp16=args.fp16,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
            )
        else:
            raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
@@ -467,11 +520,11 @@ def main():
                amp=args.amp,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
            )
        else:
            raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")


if __name__ == "__main__":
    main()
@@ -10,38 +10,37 @@ from transformers.modeling_camembert import CamembertForMaskedLM
def fill_mask(masked_input, model, tokenizer, topk=5):
    # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
    assert masked_input.count("<mask>") == 1
    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    logits = model(input_ids)[0]  # The last hidden-state is the first element of the output tuple
    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    logits = logits[0, masked_index, :]
    prob = logits.softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = " ".join(
        [tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
    )
    masked_token = tokenizer.mask_token
    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
        predicted_token = predicted_token_bpe.replace("\u2581", " ")
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append(
                (
                    masked_input.replace(" {0}".format(masked_token), predicted_token),
                    values[index].item(),
                    predicted_token,
                )
            )
        else:
            topk_filled_outputs.append(
                (masked_input.replace(masked_token, predicted_token), values[index].item(), predicted_token,)
            )
    return topk_filled_outputs


tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertForMaskedLM.from_pretrained("camembert-base")
model.eval()
masked_input = "Le camembert est <mask> :)"
...
@@ -36,34 +36,42 @@ from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

from transformers import (
    OpenAIGPTDoubleHeadsModel,
    OpenAIGPTTokenizer,
    AdamW,
    cached_path,
    WEIGHTS_NAME,
    CONFIG_NAME,
    get_linear_schedule_with_warmup,
)

ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


def load_rocstories_dataset(dataset_path):
    """ Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
    with open(dataset_path, encoding="utf_8") as f:
        f = csv.reader(f)
        output = []
        next(f)  # skip the first line
        for line in tqdm(f):
            output.append((" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
    return output


def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
    """ Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
@@ -80,56 +88,68 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
        for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
            with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
            with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
            input_ids[i, 0, : len(with_cont1)] = with_cont1
            input_ids[i, 1, : len(with_cont2)] = with_cont2
            mc_token_ids[i, 0] = len(with_cont1) - 1
            mc_token_ids[i, 1] = len(with_cont2) - 1
            lm_labels[i, 0, : len(with_cont1)] = with_cont1
            lm_labels[i, 1, : len(with_cont2)] = with_cont2
            mc_labels[i] = mc_label
        all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
        tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
    return tensor_datasets


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="openai-gpt", help="pretrained model name")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--train_dataset", type=str, default="")
    parser.add_argument("--eval_dataset", type=str, default="")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", type=int, default=1)
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training \
                        steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before\
                        performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--lr_schedule", type=str, default="warmup_linear")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--lm_coef", type=float, default=0.9)
    parser.add_argument("--n_valid", type=int, default=374)
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()
@@ -152,7 +172,7 @@ def main():
    # Load tokenizer and model
    # This loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ["_start_", "_delimiter_", "_classify_"]
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
    tokenizer.add_tokens(special_tokens)
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
@@ -163,6 +183,7 @@ def main():
    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        roc_stories = cached_path(ROCSTORIES_URL)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
@@ -170,6 +191,7 @@ def main():
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    eval_dataset = load_rocstories_dataset(args.eval_dataset)
@@ -178,8 +200,11 @@ def main():
    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
    input_length = max(
        len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
        for dataset in encoded_datasets
        for story, cont1, cont2, _ in dataset
    )
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders
@@ -198,20 +223,23 @@ def main():
    if args.do_train:
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
@@ -230,14 +258,16 @@ def main():
                optimizer.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = (
                    loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                )
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, "module") else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
@@ -260,10 +290,12 @@ def main():
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                _, mc_loss, _, mc_logits = model(
                    input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels
                )

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to("cpu").numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()
@@ -274,10 +306,8 @@ def main():
        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_loss": train_loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
@@ -286,5 +316,6 @@ def main():
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    main()
@@ -28,8 +28,7 @@ import glob
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

try:
@@ -39,31 +38,23 @@ except:
from tqdm import tqdm, trange

from transformers import WEIGHTS_NAME, BertConfig, BertForMultipleChoice, BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())

MODEL_CLASSES = {
    "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
}


class SwagExample(object):
    """A single training/test example for the SWAG dataset."""

    def __init__(self, swag_id, context_sentence, start_ending, ending_0, ending_1, ending_2, ending_3, label=None):
        self.swag_id = swag_id
        self.context_sentence = context_sentence
        self.start_ending = start_ending
@@ -94,57 +85,49 @@ class SwagExample(object):
        return ", ".join(l)


class InputFeatures(object):
    def __init__(self, example_id, choices_features, label):
        self.example_id = example_id
        self.choices_features = [
            {"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
            for _, input_ids, input_mask, segment_ids in choices_features
        ]
        self.label = label


def read_swag_examples(input_file, is_training=True):
    with open(input_file, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        lines = []
        for line in reader:
            if sys.version_info[0] == 2:
                line = list(unicode(cell, "utf-8") for cell in line)
            lines.append(line)

    if is_training and lines[0][-1] != "label":
        raise ValueError("For training, the input file must contain a label column.")

    examples = [
        SwagExample(
            swag_id=line[2],
            context_sentence=line[4],
            start_ending=line[5],  # in the swag dataset, the
            # common beginning of each
            # choice is stored in "sent2".
            ending_0=line[7],
            ending_1=line[8],
            ending_2=line[9],
            ending_3=line[10],
            label=int(line[11]) if is_training else None,
        )
        for line in lines[1:]  # we skip the line with the column names
    ]

    return examples


def convert_examples_to_features(examples, tokenizer, max_seq_length, is_training):
    """Loads a data file into a list of `InputBatch`s."""

    # Swag is a multiple choice task. To perform this task using Bert,
@@ -204,23 +187,18 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
            logger.info("swag_id: {}".format(example.swag_id))
            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
                logger.info("choice: {}".format(choice_idx))
                logger.info("tokens: {}".format(" ".join(tokens)))
                logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
                logger.info("input_mask: {}".format(" ".join(map(str, input_mask))))
                logger.info("segment_ids: {}".format(" ".join(map(str, segment_ids))))
            if is_training:
                logger.info("label: {}".format(label))

        features.append(InputFeatures(example_id=example.swag_id, choices_features=choices_features, label=label))

    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
@@ -237,18 +215,14 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        else:
            tokens_b.pop()


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


def select_field(features, field):
    return [[choice[field] for choice in feature.choices_features] for feature in features]


def set_seed(args):
@@ -258,24 +232,28 @@ def set_seed(args):
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
    cached_features_file = os.path.join(
        os.path.dirname(input_file),
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        examples = read_swag_examples(input_file)
        features = convert_examples_to_features(examples, tokenizer, args.max_seq_length, not evaluate)

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
@@ -285,21 +263,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor(select_field(features, "input_ids"), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(features, "input_mask"), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(features, "segment_ids"), dtype=torch.long)
    all_label = torch.tensor([f.label for f in features], dtype=torch.long)

    if evaluate:
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
    else:
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)

    if output_examples:
        return dataset, examples, features
    return dataset


def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
@@ -316,13 +294,18 @@ def train(args, train_dataset, model, tokenizer):
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    if args.fp16:
        try:
            from apex import amp
@@ -336,17 +319,21 @@ def train(args, train_dataset, model, tokenizer):
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
@@ -360,11 +347,13 @@ def train(args, train_dataset, model, tokenizer):
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                #'token_type_ids': None if args.model_type == 'xlm' else batch[2],
                "token_type_ids": batch[2],
                "labels": batch[3],
            }
            # if args.model_type in ['xlnet', 'xlm']:
            #     inputs.update({'cls_index': batch[5],
            #                    'p_mask': batch[6]})
@@ -393,23 +382,27 @@ def train(args, train_dataset, model, tokenizer):
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_vocabulary(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
@@ -424,6 +417,7 @@ def train(args, train_dataset, model, tokenizer):
    return global_step, tr_loss / global_step


def evaluate(args, model, tokenizer, prefix=""):
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
@@ -440,7 +434,6 @@ def evaluate(args, model, tokenizer, prefix=""):
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
@@ -448,11 +441,13 @@ def evaluate(args, model, tokenizer, prefix=""):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                # 'token_type_ids': None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
                "token_type_ids": batch[2],
                "labels": batch[3],
            }
            # if args.model_type in ['xlnet', 'xlm']:
            #     inputs.update({'cls_index': batch[4],
@@ -462,17 +457,16 @@ def evaluate(args, model, tokenizer, prefix=""):
        eval_loss += tmp_eval_loss.mean().item()
        logits = logits.detach().cpu().numpy()
        label_ids = inputs["labels"].to("cpu").numpy()
        tmp_eval_accuracy = accuracy(logits, label_ids)
        eval_accuracy += tmp_eval_accuracy
        nb_eval_steps += 1
        nb_eval_examples += inputs["input_ids"].size(0)

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples
    result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy}

    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
@@ -483,92 +477,144 @@ def evaluate(args, model, tokenizer, prefix=""):
    return result


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
parser.add_argument("--train_file", default=None, type=str, required=True, parser.add_argument(
help="SWAG csv for training. E.g., train.csv") "--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
parser.add_argument("--predict_file", default=None, type=str, required=True, )
help="SWAG csv for predictions. E.g., val.csv or test.csv") parser.add_argument(
parser.add_argument("--model_type", default=None, type=str, required=True, "--predict_file",
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) default=None,
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, type=str,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) required=True,
parser.add_argument("--output_dir", default=None, type=str, required=True, help="SWAG csv for predictions. E.g., val.csv or test.csv",
help="The output directory where the model checkpoints and predictions will be written.") )
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints and predictions will be written.",
)
    ## Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()
@@ -580,16 +626,24 @@ def main():
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)
@@ -601,8 +655,12 @@ def main():
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
    )
    model = model_class.from_pretrained(
        args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
@@ -617,7 +675,6 @@ def main():
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Save the trained model and the tokenizer
    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        # Create output directory if needed
@@ -627,19 +684,20 @@ def main():
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
@@ -650,14 +708,16 @@ def main():
        checkpoints = [args.model_name_or_path]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            tokenizer = tokenizer_class.from_pretrained(checkpoint)
            model.to(args.device)
@@ -665,7 +725,7 @@ def main():
            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))
...
@@ -30,44 +30,36 @@ import torch
from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
    parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
    parser.add_argument(
        "--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
    parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
    parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
    parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
    parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUDA is available")
    parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
    parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
    parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    assert args.ext_len >= 0, "extended context length must be non-negative"

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()
@@ -84,17 +76,18 @@ def main():
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)
    ntokens = len(corpus.vocab)

    va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
            args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
        )
    )

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
@@ -108,7 +101,7 @@ def main():
    def evaluate(eval_iter):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.0
        start_time = time.time()
        with torch.no_grad():
            mems = None
@@ -119,35 +112,34 @@ def main():
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
        return total_loss / total_len
    # Run on test data.
    if args.split == "all":
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == "valid":
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == "test":
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
        return log_str

    log_str = ""
    if valid_loss is not None:
        log_str += format_log(valid_loss, "valid")
    if test_loss is not None:
        log_str += format_log(test_loss, "test")

    logger.info("=" * 100)
    logger.info(log_str)
    logger.info("=" * 100)


if __name__ == "__main__":
    main()
@@ -40,14 +40,12 @@ from utils import logger
from lm_seqs_dataset import LmSeqsDataset
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups


class Distiller:
    def __init__(
        self, params: dict, dataset: LmSeqsDataset, token_probs: torch.tensor, student: nn.Module, teacher: nn.Module
    ):
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
@@ -70,12 +68,10 @@ class Distiller:
        else:
            sampler = BatchSampler(sampler=sampler, batch_size=params.batch_size, drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler, collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
@@ -85,18 +81,18 @@ class Distiller:
        self.mlm = params.mlm
        if self.mlm:
            logger.info(f"Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(f"cuda:{params.local_rank}") if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info(f"Using CLM loss for LM step.")

        self.epoch = 0
        self.n_iter = 0
@@ -107,38 +103,54 @@ class Distiller:
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1
        )

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i"
            % sum([p.numel() for p in self.student.parameters() if p.requires_grad])
        )
        logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=params.learning_rate, eps=params.adam_epsilon, betas=(0.9, 0.98)
        )

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
        )

        if self.fp16:
            try:
@@ -146,33 +158,36 @@ class Distiller:
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student, self.optimizer, opt_level=self.params.fp16_opt_level
            )
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.
@@ -192,7 +207,7 @@ class Distiller:
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
@@ -200,11 +215,13 @@ class Distiller:
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)
        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
@@ -213,15 +230,19 @@ class Distiller:
            pred_mask = pred_mask.view(-1)
            n2 = max(n1 % 8, 8 * (n1 // 8))
            if n2 != n1:
                pred_mask[torch.nonzero(pred_mask).view(-1)[: n1 - n2]] = 0
            pred_mask = pred_mask.view(bs, max_seq_len)
            assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
        _token_ids = (
            _token_ids_mask * (probs == 0).long()
            + _token_ids_real * (probs == 1).long()
            + _token_ids_rand * (probs == 2).long()
        )
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
@@ -231,8 +252,7 @@ class Distiller:
        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.
@@ -252,7 +272,7 @@ class Distiller:
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
@@ -261,9 +281,7 @@ class Distiller:
        return token_ids, attn_mask, clm_labels
    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
@@ -299,9 +317,9 @@ class Distiller:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)
@@ -314,20 +332,22 @@ class Distiller:
        """
        The real training loop.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f"cuda:{self.params.local_rank}") for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(batch=batch)
@@ -336,22 +356,21 @@ class Distiller:
                self.step(input_ids=token_ids, attention_mask=attn_mask, lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix(
                    {"Last_loss": f"{self.last_loss:.2f}", "Avg_cum_loss": f"{self.total_loss_epoch/self.n_iter:.2f}"}
                )
            iter_bar.close()

            if self.is_master:
                logger.info(f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        if self.is_master:
            logger.info(f"Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name=f"pytorch_model.bin")
            logger.info("Training is finished")

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor, lm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).
@@ -363,19 +382,27 @@ class Distiller:
            lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(
                input_ids=input_ids, attention_mask=attention_mask
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(
                input_ids=input_ids, attention_mask=None
            )  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=None
                )  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
@@ -384,24 +411,30 @@ class Distiller:
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = (
            self.ce_loss_fct(
                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                F.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        loss = self.alpha_ce * loss_ce
        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(
                0
            )  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)  # (bs, seq_length, dim)
@@ -420,21 +453,20 @@ class Distiller:
        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
@@ -442,7 +474,7 @@ class Distiller:
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error("NaN detected")
            exit()

        if self.multi_gpu:
@@ -452,6 +484,7 @@ class Distiller:
        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
@@ -488,53 +521,84 @@ class Distiller:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(
                tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter
            )
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(
                tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=self.n_total_iter
            )
            self.tensorboard.add_scalar(
                tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter
            )

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter)
        self.tensorboard.add_scalar(
            tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter
        )
        if self.alpha_mlm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter
            )
        if self.alpha_clm > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=self.n_total_iter
            )
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter
            )
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(
                tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=self.n_total_iter
            )
        self.tensorboard.add_scalar(
            tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter
        )

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(
            tag="global/speed", scalar_value=time.time() - self.last_log, global_step=self.n_total_iter
        )
    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
            self.tensorboard.add_scalar(
                tag="epoch/loss", scalar_value=self.total_loss_epoch / self.n_iter, global_step=self.epoch
            )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0
    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
@@ -23,12 +23,14 @@ from torch.utils.data.sampler import BatchSampler, Sampler
from utils import logger


def _quantize(x, bins):
    bins = copy.deepcopy(bins)
    bins = sorted(bins)
    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
    return quantized


def create_lengths_groups(lengths, k=0):
    bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
    groups = _quantize(lengths, bins)
@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0):
    logger.info("Count of instances per bin: {}".format(counts))
    return groups


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler):
        0, i.e. they must be in the range [0, num_groups).
        batch_size (int): Size of mini-batch.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = group_ids
@@ -73,7 +76,7 @@ class GroupedBatchSampler(BatchSampler):
            buffer_per_group[group_id].append(idx)
            samples_per_group[group_id].append(idx)
            if len(buffer_per_group[group_id]) == self.batch_size:
                yield buffer_per_group[group_id]  # TODO
                num_batches += 1
                del buffer_per_group[group_id]
            assert len(buffer_per_group[group_id]) < self.batch_size
@@ -90,8 +93,8 @@ class GroupedBatchSampler(BatchSampler):
            for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
                batch_idx.extend(idxs)
                if len(batch_idx) >= self.batch_size:
                    yield batch_idx[: self.batch_size]
                    batch_idx = batch_idx[self.batch_size :]
                    num_remaining -= 1
            if len(batch_idx) > 0:
                yield batch_idx
...
@@ -21,6 +21,7 @@ from torch.utils.data import Dataset
import numpy as np

from utils import logger


class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.
@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset):
        data: `List[np.array[int]]
    """

    def __init__(self, params, data):
        self.params = params
        self.token_ids = np.array(data)
@@ -65,17 +64,17 @@ class LmSeqsDataset(Dataset):
        """
        max_len = self.params.max_model_input_size
        indices = self.lengths > max_len
        logger.info(f"Splitting {sum(indices)} too long sequences.")

        def divide_chunks(l, n):
            return [l[i : i + n] for i in range(0, len(l), n)]

        new_tok_ids = []
        new_lengths = []
        if self.params.mlm:
            cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
        else:
            cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]

        for seq_, len_ in zip(self.token_ids, self.lengths):
            assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
@@ -84,7 +83,7 @@ class LmSeqsDataset(Dataset):
                new_lengths.append(len_)
            else:
                sub_seqs = []
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:
@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset):
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")

    def print_statistics(self):
        """
@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset):
        """
        if not self.params.is_master:
            return
        logger.info(f"{len(self)} sequences")
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset):
        # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')

    def batch_sequences(self, batch):
        """
        Do the padding and transform into torch.tensor.
        """
@@ -139,10 +137,10 @@ class LmSeqsDataset(Dataset):
        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids["pad_token"]
        else:
            pad_idx = self.params.special_tok_ids["unk_token"]
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]

        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)
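As a rough illustration of the padding step in batch_sequences above (not part of the commit; pad_idx and the toy sequences are invented for the example, the real code takes them from params.special_tok_ids):

import numpy as np

token_ids = [np.array([101, 7, 8, 102]), np.array([101, 9, 102])]
lengths = [len(t) for t in token_ids]
max_seq_len_ = max(lengths)
pad_idx = 0  # stands in for params.special_tok_ids["pad_token"]
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
print(tk_)  # -> [[101, 7, 8, 102], [101, 9, 102, 0]]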
...
@@ -25,8 +25,7 @@ import glob
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F
import torch.nn as nn
@@ -38,19 +37,32 @@ except:
from tqdm import tqdm, trange

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForQuestionAnswering,
    BertTokenizer,
    XLMConfig,
    XLMForQuestionAnswering,
    XLMTokenizer,
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertForQuestionAnswering,
    DistilBertTokenizer,
)

from transformers import AdamW, get_linear_schedule_with_warmup

from ..utils_squad import (
    read_squad_examples,
    convert_examples_to_features,
    RawResult,
    write_predictions,
    RawResultExtended,
    write_predictions_extended,
)

# The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
@@ -59,16 +71,18 @@ from ..utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
logger = logging.getLogger(__name__)

ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
)

MODEL_CLASSES = {
    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
}


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
@@ -76,9 +90,11 @@ def set_seed(args):
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def to_list(tensor):
    return tensor.detach().cpu().tolist()


def train(args, train_dataset, model, tokenizer, teacher=None):
    """ Train the model """
    if args.local_rank in [-1, 0]:
@@ -95,13 +111,18 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    if args.fp16:
        try:
            from apex import amp
@@ -115,17 +136,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
@@ -141,37 +166,44 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
            if teacher is not None:
                teacher.eval()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
            outputs = model(**inputs)
            loss, start_logits_stu, end_logits_stu = outputs

            # Distillation loss
            if teacher is not None:
                if "token_type_ids" not in inputs:
                    inputs["token_type_ids"] = None if args.teacher_type == "xlm" else batch[2]
                with torch.no_grad():
                    start_logits_tea, end_logits_tea = teacher(
                        input_ids=inputs["input_ids"],
                        token_type_ids=inputs["token_type_ids"],
                        attention_mask=inputs["attention_mask"],
                    )
                assert start_logits_tea.size() == start_logits_stu.size()
                assert end_logits_tea.size() == end_logits_stu.size()

                loss_fct = nn.KLDivLoss(reduction="batchmean")
                loss_start = loss_fct(
                    F.log_softmax(start_logits_stu / args.temperature, dim=-1),
                    F.softmax(start_logits_tea / args.temperature, dim=-1),
                ) * (args.temperature ** 2)
                loss_end = loss_fct(
                    F.log_softmax(end_logits_stu / args.temperature, dim=-1),
                    F.softmax(end_logits_tea / args.temperature, dim=-1),
                ) * (args.temperature ** 2)
                loss_ce = (loss_start + loss_end) / 2.0

                loss = args.alpha_ce * loss_ce + args.alpha_squad * loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
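For reference, a self-contained sketch of the temperature-scaled KL distillation term used above (not part of the commit; the temperature value and tensor shapes are made up for the example):

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 2.0  # distillation temperature (the script's --temperature default)
student_logits = torch.randn(4, 384)  # (batch, sequence length) start logits
teacher_logits = torch.randn(4, 384)
loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_start = loss_fct(
    F.log_softmax(student_logits / T, dim=-1),
    F.softmax(teacher_logits / T, dim=-1),
) * (T ** 2)  # scaling by T**2 keeps gradient magnitudes comparable across temperatures
print(loss_start.item())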
@@ -195,22 +227,26 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
@@ -246,32 +282,31 @@ def evaluate(args, model, tokenizer, prefix=""):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]  # XLM don't use segment_ids
            example_indices = batch[3]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
            outputs = model(**inputs)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            if args.model_type in ["xlnet", "xlm"]:
                # XLNet uses a more complex post-processing procedure
                result = RawResultExtended(
                    unique_id=unique_id,
                    start_top_log_probs=to_list(outputs[0][i]),
                    start_top_index=to_list(outputs[1][i]),
                    end_top_log_probs=to_list(outputs[2][i]),
                    end_top_index=to_list(outputs[3][i]),
                    cls_logits=to_list(outputs[4][i]),
                )
            else:
                result = RawResult(
                    unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
                )
            all_results.append(result)

    # Compute predictions
@@ -282,23 +317,44 @@ def evaluate(args, model, tokenizer, prefix=""):
    else:
        output_null_log_odds_file = None

    if args.model_type in ["xlnet", "xlm"]:
        # XLNet uses a more complex post-processing procedure
        write_predictions_extended(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.predict_file,
            model.config.start_n_top,
            model.config.end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        write_predictions(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
        )

    # Evaluate with the official SQuAD script
    evaluate_options = EVAL_OPTS(
        data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
    )
    results = evaluate_on_squad(evaluate_options)
    return results
@@ -309,24 +365,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
    cached_features_file = os.path.join(
        os.path.dirname(input_file),
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        examples = read_squad_examples(
            input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
        )
        features = convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
@@ -342,14 +404,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    if evaluate:
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(
            all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
        )
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        dataset = TensorDataset(
            all_input_ids,
            all_input_mask,
            all_segment_ids,
            all_start_positions,
            all_end_positions,
            all_cls_index,
            all_p_mask,
        )

    if output_examples:
        return dataset, examples, features
@@ -360,121 +429,213 @@ def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        required=True,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )

    # Distillation parameters (optional)
    parser.add_argument(
        "--teacher_type",
        default=None,
        type=str,
        help="Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
    )
    parser.add_argument(
        "--teacher_name_or_path",
        default=None,
        type=str,
        help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
    )
    parser.add_argument(
        "--alpha_ce", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
    )
    parser.add_argument(
        "--alpha_squad", default=0.5, type=float, help="True SQuAD loss linear weight. Only for distillation."
    )
    parser.add_argument(
        "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
    )
    ## Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for training.") parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, parser.add_argument(
help="Batch size per GPU/CPU for evaluation.") "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
parser.add_argument("--learning_rate", default=5e-5, type=float, )
help="The initial learning rate for Adam.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, parser.add_argument(
help="Number of updates steps to accumulate before performing a backward/update pass.") "--gradient_accumulation_steps",
parser.add_argument("--weight_decay", default=0.0, type=float, type=int,
help="Weight deay if we apply some.") default=1,
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Number of updates steps to accumulate before performing a backward/update pass.",
help="Epsilon for Adam optimizer.") )
parser.add_argument("--max_grad_norm", default=1.0, type=float, parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
help="Max gradient norm.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--num_train_epochs", default=3.0, type=float, parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
help="Total number of training epochs to perform.") parser.add_argument(
parser.add_argument("--max_steps", default=-1, type=int, "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
help="If > 0: set total number of training steps to perform. Override num_train_epochs.") )
parser.add_argument("--warmup_steps", default=0, type=int, parser.add_argument(
help="Linear warmup over warmup_steps.") "--max_steps",
parser.add_argument("--n_best_size", default=20, type=int, default=-1,
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.") type=int,
parser.add_argument("--max_answer_length", default=30, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )

    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()
@@ -486,16 +647,24 @@ def main():
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)
@@ -506,27 +675,34 @@ def main():
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.teacher_type is not None:
        assert args.teacher_name_or_path is not None
        assert args.alpha_ce > 0.0
        assert args.alpha_ce + args.alpha_squad > 0.0
        assert args.teacher_type != "distilbert", "We constrain teachers not to be of type DistilBERT."
        teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
        teacher_config = teacher_config_class.from_pretrained(
            args.teacher_name_or_path, cache_dir=args.cache_dir if args.cache_dir else None
        )
        teacher = teacher_model_class.from_pretrained(
            args.teacher_name_or_path, config=teacher_config, cache_dir=args.cache_dir if args.cache_dir else None
        )
        teacher.to(args.device)
    else:
        teacher = None
@@ -544,7 +720,6 @@ def main():
        global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
@@ -554,41 +729,44 @@ def main():
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None
        )
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)
            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))
...
@@ -23,68 +23,65 @@ import numpy as np
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer

import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
    )
    parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
    parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
    parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
    parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
    args = parser.parse_args()

    logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
    if args.tokenizer_type == "bert":
        tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
        bos = tokenizer.special_tokens_map["cls_token"]  # `[CLS]`
        sep = tokenizer.special_tokens_map["sep_token"]  # `[SEP]`
    elif args.tokenizer_type == "roberta":
        tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
        bos = tokenizer.special_tokens_map["cls_token"]  # `<s>`
        sep = tokenizer.special_tokens_map["sep_token"]  # `</s>`
    elif args.tokenizer_type == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
        bos = tokenizer.special_tokens_map["bos_token"]  # `<|endoftext|>`
        sep = tokenizer.special_tokens_map["eos_token"]  # `<|endoftext|>`

    logger.info(f"Loading text from {args.file_path}")
    with open(args.file_path, "r", encoding="utf8") as fp:
        data = fp.readlines()

    logger.info(f"Start encoding")
    logger.info(f"{len(data)} examples to process.")

    rslt = []
    iter = 0
    interval = 10000
    start = time.time()
    for text in data:
        text = f"{bos} {text.strip()} {sep}"
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        rslt.append(token_ids)

        iter += 1
        if iter % interval == 0:
            end = time.time()
            logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
            start = time.time()
    logger.info("Finished binarization")
    logger.info(f"{len(data)} examples processed.")

    dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
    rslt_ = [np.uint16(d) for d in rslt]
    random.shuffle(rslt_)
    logger.info(f"Dump to {dp_file}")
    with open(dp_file, "wb") as handle:
        pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
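To make the output format above concrete, a hedged sketch (not part of the commit) of reading such a dump back; the file name follows the script's defaults and the decode step assumes the BERT tokenizer was used:

import pickle
from transformers import BertTokenizer

with open("data/dump.bert-base-uncased.pickle", "rb") as handle:
    sequences = pickle.load(handle)  # list of np.uint16 arrays of token ids

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.decode([int(i) for i in sequences[0]]))  # round-trip the first sequence back to text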
...
@@ -20,70 +20,80 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
import torch
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
    )
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default="roberta-large", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

    if args.model_type == "roberta":
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = "roberta"
    elif args.model_type == "gpt2":
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        prefix = "transformer"

    state_dict = model.state_dict()
    compressed_sd = {}

    ### Embeddings ###
    if args.model_type == "gpt2":
        for param_name in ["wte.weight", "wpe.weight"]:
            compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
    else:
        for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
            param_name = f"{prefix}.embeddings.{w}.weight"
            compressed_sd[param_name] = state_dict[param_name]
        for w in ["weight", "bias"]:
            param_name = f"{prefix}.embeddings.LayerNorm.{w}"
            compressed_sd[param_name] = state_dict[param_name]

    ### Transformer Blocks ###
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        if args.model_type == "gpt2":
            for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.h.{teacher_idx}.{layer}.{w}"
                    ]
            compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
        else:
            for layer in [
                "attention.self.query",
                "attention.self.key",
                "attention.self.value",
                "attention.output.dense",
                "attention.output.LayerNorm",
                "intermediate.dense",
                "output.dense",
                "output.LayerNorm",
            ]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
                    ]
        std_idx += 1

    ### Language Modeling Head ###
    if args.model_type == "roberta":
        for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
            compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
        if args.vocab_transform:
            for w in ["weight", "bias"]:
                compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
                compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
    elif args.model_type == "gpt2":
        for w in ["weight", "bias"]:
            compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
        compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
    print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
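The checkpoint written by torch.save above is a plain dict of tensors, so it can be inspected directly. The sketch below is illustrative only; the path simply reuses the --dump_checkpoint default.

import torch

compressed_sd = torch.load("serialization_dir/tf_roberta_048131723.pth", map_location="cpu")
print(len(compressed_sd), "tensors kept")                  # embeddings + 6 blocks + LM head
print([k for k in compressed_sd if ".layer.5." in k][:4])  # student layers are renumbered 0..5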
@@ -20,63 +20,70 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM
import torch
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract some layers of the full BertForMaskedLM or RobertaForMaskedLM for Transfer Learned Distillation"
    )
    parser.add_argument("--model_type", default="bert", choices=["bert"])
    parser.add_argument("--model_name", default="bert-base-uncased", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

    if args.model_type == "bert":
        model = BertForMaskedLM.from_pretrained(args.model_name)
        prefix = "bert"
    else:
        raise ValueError(f'args.model_type should be "bert".')

    state_dict = model.state_dict()
    compressed_sd = {}

    for w in ["word_embeddings", "position_embeddings"]:
        compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
    for w in ["weight", "bias"]:
        compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]

    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ["weight", "bias"]:
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
            ]
        std_idx += 1

    compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
    compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
    if args.vocab_transform:
        for w in ["weight", "bias"]:
            compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
            compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
    print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
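One plausible way to consume this checkpoint is to load it into a DistilBERT student with strict=False, since the vocab_transform/vocab_layer_norm entries only exist when --vocab_transform was passed. This is a sketch under those assumptions, not the repository's training entry point.

import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

config = DistilBertConfig()  # defaults: 6 layers, dim 768, vocab 30522 -- matches a bert-base-uncased teacher
student = DistilBertForMaskedLM(config)
state_dict = torch.load("serialization_dir/tf_bert-base-uncased_0247911.pth", map_location="cpu")
missing, unexpected = student.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)  # e.g. vocab_transform.* when --vocab_transform was not used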
@@ -20,32 +20,36 @@ import argparse
import pickle
import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
    )
    parser.add_argument(
        "--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
    )
    parser.add_argument(
        "--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
    )
    parser.add_argument("--vocab_size", default=30522, type=int)
    args = parser.parse_args()

    logger.info(f"Loading data from {args.data_file}")
    with open(args.data_file, "rb") as fp:
        data = pickle.load(fp)

    logger.info("Counting occurrences for MLM.")
    counter = Counter()
    for tk_ids in data:
        counter.update(tk_ids)
    counts = [0] * args.vocab_size
    for k, v in counter.items():
        counts[k] = v

    logger.info(f"Dump to {args.token_counts_dump}")
    with open(args.token_counts_dump, "wb") as handle:
        pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
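For context, train.py (further down in this diff) converts these counts into smoothed masking probabilities with np.maximum(counts, 1) ** -args.mlm_smoothing. A toy illustration follows, with made-up counts and the 0.7 default.

import numpy as np

counts = np.array([0, 120, 3, 4500])         # toy occurrence counts indexed by token id
token_probs = np.maximum(counts, 1) ** -0.7  # rarer tokens receive relatively more masking weight
print(token_probs)                           # roughly [1.0, 0.035, 0.46, 0.003]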
@@ -35,166 +35,200 @@ from lm_seqs_dataset import LmSeqsDataset
MODEL_CLASSES = {
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
}


def sanity_checks(args):
    """
    A bunch of args sanity checks to perform even before starting...
    """
    assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
    assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
    if args.mlm:
        assert os.path.isfile(args.token_counts)
        assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
    else:
        assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])

    assert args.teacher_type == args.student_type or (
        args.student_type == "distilbert" and args.teacher_type == "bert"
    )
    assert os.path.isfile(args.student_config)
    if args.student_pretrained_weights is not None:
        assert os.path.isfile(args.student_pretrained_weights)

    if args.freeze_token_type_embds:
        assert args.student_type in ["roberta"]

    assert args.alpha_ce >= 0.0
    assert args.alpha_mlm >= 0.0
    assert args.alpha_clm >= 0.0
    assert args.alpha_mse >= 0.0
    assert args.alpha_cos >= 0.0
    assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0
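As an illustration of what these checks accept, here is one hypothetical flag combination for MLM distillation (BERT teacher, DistilBERT student); all paths are placeholders, and the os.path.isfile checks would still require the files to exist.

from argparse import Namespace

args = Namespace(
    mlm=True, alpha_ce=5.0, alpha_mlm=2.0, alpha_clm=0.0, alpha_mse=0.0, alpha_cos=1.0,
    student_type="distilbert", teacher_type="bert",
    student_config="training_configs/distilbert-base-uncased.json",   # placeholder path
    student_pretrained_weights=None,
    token_counts="data/token_counts.bert-base-uncased.pickle",        # placeholder path
    freeze_token_type_embds=False,
)
# sanity_checks(args) would pass once the two files above actually exist on disk.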
def freeze_pos_embeddings(student, args):
    if args.student_type == "roberta":
        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
    elif args.student_type == "gpt2":
        student.transformer.wpe.weight.requires_grad = False


def freeze_token_type_embeddings(student, args):
    if args.student_type == "roberta":
        student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False


def main():
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")

    parser.add_argument(
        "--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
    )
    parser.add_argument(
        "--data_file",
        type=str,
        required=True,
        help="The binarized file (tokenized + tokens_to_ids) and grouped by sequence.",
    )

    parser.add_argument(
        "--student_type",
        type=str,
        choices=["distilbert", "roberta", "gpt2"],
        required=True,
        help="The student type (DistilBERT, RoBERTa).",
    )
    parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
    parser.add_argument(
        "--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
    )

    parser.add_argument(
        "--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa)."
    )
    parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")

    parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax.")
    parser.add_argument(
        "--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
    )
    parser.add_argument(
        "--alpha_mlm",
        default=0.0,
        type=float,
        help="Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with `mlm` flag.",
    )
    parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument(
        "--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
    )

    parser.add_argument(
        "--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is True, the MLM is used over CLM."
    )
    parser.add_argument(
        "--mlm_mask_prop",
        default=0.15,
        type=float,
        help="Proportion of tokens for which we need to make a prediction.",
    )
    parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of tokens to mask out.")
    parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of tokens to keep.")
    parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of tokens to randomly replace.")
    parser.add_argument(
        "--mlm_smoothing",
        default=0.7,
        type=float,
        help="Smoothing parameter to emphasize more rare tokens (see XLM, similar to word2vec).",
    )
    parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")

    parser.add_argument(
        "--restrict_ce_to_mask",
        action="store_true",
        help="If true, compute the distillation loss only on the [MLM] prediction distribution.",
    )
    parser.add_argument(
        "--freeze_pos_embs",
        action="store_true",
        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
    )
    parser.add_argument(
        "--freeze_token_type_embds",
        action="store_true",
        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
    )

    parser.add_argument("--n_epoch", type=int, default=3, help="Number of passes over the whole dataset.")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
    parser.add_argument(
        "--group_by_size",
        action="store_false",
        help="If true, group sequences that have similar length into the same batch. Default is true.",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=50,
        help="Gradient accumulation for larger training batches.",
    )
    parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
    parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
    parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank")
    parser.add_argument("--seed", type=int, default=56, help="Random seed")

    parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
    parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
    args = parser.parse_args()
    sanity_checks(args)

    ## ARGS ##
    init_gpu_params(args)
    set_seed(args)
    if args.is_master:
        if os.path.exists(args.dump_path):
            if not args.force:
                raise ValueError(
                    f"Serialization dir {args.dump_path} already exists, but you have not specified whether to overwrite it. "
                    "Use `--force` if you want to overwrite it."
                )
            else:
                shutil.rmtree(args.dump_path)

        if not os.path.exists(args.dump_path):
            os.makedirs(args.dump_path)
        logger.info(f"Experiment will be dumped and logged in {args.dump_path}")

        ### SAVE PARAMS ###
        logger.info(f"Param: {args}")
        with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
            json.dump(vars(args), f, indent=4)
        git_log(args.dump_path)
@@ -207,58 +241,50 @@ def main():
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f"Special tokens {special_tok_ids}")
    args.special_tok_ids = special_tok_ids
    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

    ## DATA LOADER ##
    logger.info(f"Loading data from {args.data_file}")
    with open(args.data_file, "rb") as fp:
        data = pickle.load(fp)

    if args.mlm:
        logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
        with open(args.token_counts, "rb") as fp:
            counts = pickle.load(fp)

        token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
        for idx in special_tok_ids.values():
            token_probs[idx] = 0.0  # do not predict special tokens
        token_probs = torch.from_numpy(token_probs)
    else:
        token_probs = None

    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
    logger.info(f"Data loader created.")

    ## STUDENT ##
    logger.info(f"Loading student config from {args.student_config}")
    stu_architecture_config = student_config_class.from_pretrained(args.student_config)
    stu_architecture_config.output_hidden_states = True

    if args.student_pretrained_weights is not None:
        logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
        student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
    else:
        student = student_model_class(stu_architecture_config)

    if args.n_gpu > 0:
        student.to(f"cuda:{args.local_rank}")
    logger.info(f"Student loaded.")

    ## TEACHER ##
    teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
    if args.n_gpu > 0:
        teacher.to(f"cuda:{args.local_rank}")
    logger.info(f"Teacher loaded from {args.teacher_name}.")

    ## FREEZING ##
    if args.freeze_pos_embs:
@@ -266,7 +292,6 @@ def main():
    if args.freeze_token_type_embds:
        freeze_token_type_embeddings(student, args)

    ## SANITY CHECKS ##
    assert student.config.vocab_size == teacher.config.vocab_size
    assert student.config.hidden_size == teacher.config.hidden_size
@@ -274,14 +299,11 @@ def main():
    if args.mlm:
        assert token_probs.size(0) == stu_architecture_config.vocab_size

    ## DISTILLER ##
    torch.cuda.empty_cache()
    distiller = Distiller(
        params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
    )
    distiller.train()
    logger.info("Let's go get some drinks.")
@@ -23,9 +23,12 @@ import torch
import numpy as np

import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
@@ -35,12 +38,12 @@ def git_log(folder_path: str):
    """
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }

    with open(os.path.join(folder_path, "git_log.json"), "w") as f:
        json.dump(repo_infos, f, indent=4)
@@ -57,21 +60,21 @@ def init_gpu_params(params):
    assert torch.cuda.is_available()

    logger.info("Initializing GPUs")
    if params.n_gpu > 1:
        assert params.local_rank != -1

        params.world_size = int(os.environ["WORLD_SIZE"])
        params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
        params.global_rank = int(os.environ["RANK"])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
        params.multi_gpu = True

        assert params.n_nodes == int(os.environ["N_NODES"])
        assert params.node_id == int(os.environ["NODE_RANK"])

    # local job (single GPU)
    else:
@@ -114,8 +117,7 @@ def init_gpu_params(params):
    if params.multi_gpu:
        logger.info("Initializing PyTorch distributed")
        torch.distributed.init_process_group(
            init_method="env://", backend="nccl",
        )
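init_gpu_params() reads its topology from environment variables; the sketch below shows the variables a multi-GPU launch is expected to provide (illustrative values, normally set by the process launcher rather than by hand).

import os

os.environ["WORLD_SIZE"] = "8"   # total number of distributed processes
os.environ["N_GPU_NODE"] = "4"   # GPUs (processes) per node
os.environ["RANK"] = "0"         # global rank of this process
os.environ["N_NODES"] = "2"
os.environ["NODE_RANK"] = "0"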
@@ -40,29 +40,49 @@ from tqdm import tqdm, trange
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertModel,
    BertTokenizer,
    RobertaConfig,
    RobertaModel,
    RobertaTokenizer,
    XLMConfig,
    XLMModel,
    XLMTokenizer,
    XLNetConfig,
    XLNetModel,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertModel,
    DistilBertTokenizer,
    AlbertConfig,
    AlbertModel,
    AlbertTokenizer,
    MMBTForClassification,
    MMBTConfig,
)

from transformers import AdamW, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)

ALL_MODELS = sum(
    (
        tuple(conf.pretrained_config_archive_map.keys())
        for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
    ),
    (),
)

MODEL_CLASSES = {
    "bert": (BertConfig, BertModel, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
    "xlm": (XLMConfig, XLMModel, XLMTokenizer),
    "roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
    "albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
}
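A minimal sketch of how a table like MODEL_CLASSES can be consumed (assuming the dict above is in scope; the model name and flag value are hypothetical):

config_class, model_class, tokenizer_class = MODEL_CLASSES["bert"]
tokenizer = tokenizer_class.from_pretrained("bert-base-uncased", do_lower_case=True)
transformer = model_class.from_pretrained("bert-base-uncased")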
@@ -81,10 +101,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
    )

    if args.max_steps > 0:
        t_total = args.max_steps
@@ -93,14 +116,19 @@ def train(args, train_dataset, model, tokenizer, criterion):
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    if args.fp16:
        try:
            from apex import amp
@@ -114,17 +142,21 @@ def train(args, train_dataset, model, tokenizer, criterion):
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
@@ -140,11 +172,13 @@ def train(args, train_dataset, model, tokenizer, criterion):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            labels = batch[5]
            inputs = {
                "input_ids": batch[0],
                "input_modal": batch[2],
                "attention_mask": batch[1],
                "modal_start_tokens": batch[3],
                "modal_end_tokens": batch[4],
            }
            outputs = model(**inputs)
            logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            loss = criterion(logits, labels)
@@ -174,30 +208,34 @@ def train(args, train_dataset, model, tokenizer, criterion):
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer, criterion)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
@@ -209,8 +247,8 @@ def train(args, train_dataset, model, tokenizer, criterion):
        if args.local_rank == -1:
            results = evaluate(args, model, tokenizer, criterion)
            if results["micro_f1"] > best_f1:
                best_f1 = results["micro_f1"]
                n_no_improve = 0
            else:
                n_no_improve += 1
@@ -236,7 +274,9 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
    )

    # multi-gpu eval
    if args.n_gpu > 1:
@@ -257,11 +297,13 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
        with torch.no_grad():
            batch = tuple(t.to(args.device) for t in batch)
            labels = batch[5]
            inputs = {
                "input_ids": batch[0],
                "input_modal": batch[2],
                "attention_mask": batch[1],
                "modal_start_tokens": batch[3],
                "modal_end_tokens": batch[4],
            }
            outputs = model(**inputs)
            logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            tmp_eval_loss = criterion(logits, labels)
@@ -278,7 +320,7 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
    result = {
        "loss": eval_loss,
        "macro_f1": f1_score(out_label_ids, preds, average="macro"),
        "micro_f1": f1_score(out_label_ids, preds, average="micro"),
    }

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
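The two F1 aggregations differ on multi-label data like MM-IMDB genres; a toy comparison (sketch only, with made-up labels):

import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([[1, 0, 1], [0, 1, 0]])  # two samples, three labels
y_pred = np.array([[1, 0, 0], [0, 1, 0]])
print(f1_score(y_true, y_pred, average="macro"))  # unweighted mean of per-label F1
print(f1_score(y_true, y_pred, average="micro"))  # F1 over the pooled label decisions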
@@ -303,94 +345,147 @@ def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .jsonl files for MMIMDB.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    ## Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--num_image_embeds", default=1, type=int, help="Number of Image Embeddings from the Image Encoder"
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
# Setup distant debugging if needed # Setup distant debugging if needed
if args.server_ip and args.server_port: if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
...@@ -402,17 +497,25 @@ def main(): ...@@ -402,17 +497,25 @@ def main():
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank) device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1 args.n_gpu = 1
args.device = device args.device = device
# Setup logging # Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) datefmt="%m/%d/%Y %H:%M:%S",
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) )
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
# Set seed # Set seed
set_seed(args) set_seed(args)
...@@ -426,13 +529,17 @@ def main(): ...@@ -426,13 +529,17 @@ def main():
num_labels = len(labels) num_labels = len(labels)
args.model_type = args.model_type.lower() args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
transformer_config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) transformer_config = config_class.from_pretrained(
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, args.config_name if args.config_name else args.model_name_or_path
)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case, do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None) cache_dir=args.cache_dir if args.cache_dir else None,
transformer = model_class.from_pretrained(args.model_name_or_path, )
config=transformer_config, transformer = model_class.from_pretrained(
cache_dir=args.cache_dir if args.cache_dir else None) args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
)
img_encoder = ImageEncoder(args) img_encoder = ImageEncoder(args)
config = MMBTConfig(transformer_config, num_labels=num_labels) config = MMBTConfig(transformer_config, num_labels=num_labels)
model = MMBTForClassification(config, transformer, img_encoder) model = MMBTForClassification(config, transformer, img_encoder)
...@@ -449,12 +556,13 @@ def main(): ...@@ -449,12 +556,13 @@ def main():
train_dataset = load_examples(args, tokenizer, evaluate=False) train_dataset = load_examples(args, tokenizer, evaluate=False)
label_frequences = train_dataset.get_label_frequencies() label_frequences = train_dataset.get_label_frequencies()
label_frequences = [label_frequences[l] for l in labels] label_frequences = [label_frequences[l] for l in labels]
label_weights = (torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)) ** -1 label_weights = (
torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)
) ** -1
criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights) criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion) global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed # Create output directory if needed
...@@ -464,12 +572,14 @@ def main(): ...@@ -464,12 +572,14 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`. # Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME)) torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = MMBTForClassification(config, transformer, img_encoder) model = MMBTForClassification(config, transformer, img_encoder)
...@@ -477,24 +587,25 @@ def main(): ...@@ -477,24 +587,25 @@ def main():
tokenizer = tokenizer_class.from_pretrained(args.output_dir) tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device) model.to(args.device)
# Evaluation # Evaluation
results = {} results = {}
if args.do_eval and args.local_rank in [-1, 0]: if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else "" prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
model = MMBTForClassification(config, transformer, img_encoder) model = MMBTForClassification(config, transformer, img_encoder)
model.load_state_dict(torch.load(checkpoint)) model.load_state_dict(torch.load(checkpoint))
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, criterion, prefix=prefix) result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
return results return results
......
...@@ -25,17 +25,7 @@ import torchvision ...@@ -25,17 +25,7 @@ import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.utils.data import Dataset from torch.utils.data import Dataset
POOLING_BREAKDOWN = { POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
1: (1, 1),
2: (2, 1),
3: (3, 1),
4: (2, 2),
5: (5, 1),
6: (3, 2),
7: (7, 1),
8: (4, 2),
9: (3, 3)
}
class ImageEncoder(nn.Module): class ImageEncoder(nn.Module):
...@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module): ...@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module):
return out # BxNx2048 return out # BxNx2048
class JsonlDataset(Dataset): class JsonlDataset(Dataset):
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length): def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
self.data = [json.loads(l) for l in open(data_path)] self.data = [json.loads(l) for l in open(data_path)]
...@@ -72,7 +61,7 @@ class JsonlDataset(Dataset): ...@@ -72,7 +61,7 @@ class JsonlDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True)) sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1] start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
sentence = sentence[:self.max_seq_length] sentence = sentence[: self.max_seq_length]
label = torch.zeros(self.n_classes) label = torch.zeros(self.n_classes)
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1 label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
...@@ -80,8 +69,13 @@ class JsonlDataset(Dataset): ...@@ -80,8 +69,13 @@ class JsonlDataset(Dataset):
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB") image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
image = self.transforms(image) image = self.transforms(image)
return {"image_start_token": start_token, "image_end_token": end_token, return {
"sentence": sentence, "image": image, "label": label} "image_start_token": start_token,
"image_end_token": end_token,
"sentence": sentence,
"image": image,
"label": label,
}
def get_label_frequencies(self): def get_label_frequencies(self):
label_freqs = Counter() label_freqs = Counter()
...@@ -110,10 +104,31 @@ def collate_fn(batch): ...@@ -110,10 +104,31 @@ def collate_fn(batch):
def get_mmimdb_labels(): def get_mmimdb_labels():
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance', return [
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure', "Crime",
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music', "Drama",
'Musical', 'Animation', 'Biography', 'Film-Noir'] "Thriller",
"Action",
"Comedy",
"Romance",
"Documentary",
"Short",
"Mystery",
"History",
"Family",
"Adventure",
"Fantasy",
"Sci-Fi",
"Western",
"Horror",
"Sport",
"War",
"Music",
"Musical",
"Animation",
"Biography",
"Film-Noir",
]
def get_image_transforms(): def get_image_transforms():
...@@ -122,9 +137,6 @@ def get_image_transforms(): ...@@ -122,9 +137,6 @@ def get_image_transforms():
transforms.Resize(256), transforms.Resize(256),
transforms.CenterCrop(224), transforms.CenterCrop(224),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
mean=[0.46777044, 0.44531429, 0.40661017],
std=[0.12221994, 0.12145835, 0.14380469],
),
] ]
) )
import torch import torch
class ClassificationHead(torch.nn.Module): class ClassificationHead(torch.nn.Module):
"""Classification Head for transformer encoders""" """Classification Head for transformer encoders"""
......
#! /usr/bin/env python3 #! /usr/bin/env python3
# coding=utf-8 # coding=utf-8
#Copyright (c) 2019 Uber Technologies, Inc. # Copyright (c) 2019 Uber Technologies, Inc.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
#http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
""" """
Example command with bag of words: Example command with bag of words:
...@@ -46,13 +46,13 @@ SMALL_CONST = 1e-15 ...@@ -46,13 +46,13 @@ SMALL_CONST = 1e-15
BIG_CONST = 1e10 BIG_CONST = 1e10
BAG_OF_WORDS_ARCHIVE_MAP = { BAG_OF_WORDS_ARCHIVE_MAP = {
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt", "legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt", "military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt", "politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt", "religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt", "science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt", "space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt", "technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
} }
DISCRIMINATOR_MODELS_PARAMS = { DISCRIMINATOR_MODELS_PARAMS = {
...@@ -75,10 +75,10 @@ DISCRIMINATOR_MODELS_PARAMS = { ...@@ -75,10 +75,10 @@ DISCRIMINATOR_MODELS_PARAMS = {
} }
def to_var(x, requires_grad=False, volatile=False, device='cuda'): def to_var(x, requires_grad=False, volatile=False, device="cuda"):
if torch.cuda.is_available() and device == 'cuda': if torch.cuda.is_available() and device == "cuda":
x = x.cuda() x = x.cuda()
elif device != 'cuda': elif device != "cuda":
x = x.to(device) x = x.to(device)
return Variable(x, requires_grad=requires_grad, volatile=volatile) return Variable(x, requires_grad=requires_grad, volatile=volatile)
...@@ -95,11 +95,8 @@ def top_k_filter(logits, k, probs=False): ...@@ -95,11 +95,8 @@ def top_k_filter(logits, k, probs=False):
values = torch.topk(logits, k)[0] values = torch.topk(logits, k)[0]
batch_mins = values[:, -1].view(-1, 1).expand_as(logits) batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
if probs: if probs:
return torch.where(logits < batch_mins, return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
torch.ones_like(logits) * 0.0, logits) return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)
return torch.where(logits < batch_mins,
torch.ones_like(logits) * -BIG_CONST,
logits)
def perturb_past( def perturb_past(
...@@ -121,23 +118,16 @@ def perturb_past( ...@@ -121,23 +118,16 @@ def perturb_past(
decay=False, decay=False,
gamma=1.5, gamma=1.5,
kl_scale=0.01, kl_scale=0.01,
device='cuda', device="cuda",
): ):
# Generate initial perturbed past # Generate initial perturbed past
grad_accumulator = [ grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]
(np.zeros(p.shape).astype("float32"))
for p in past
]
if accumulated_hidden is None: if accumulated_hidden is None:
accumulated_hidden = 0 accumulated_hidden = 0
if decay: if decay:
decay_mask = torch.arange( decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
0.,
1.0 + SMALL_CONST,
1.0 / (window_length)
)[1:]
else: else:
decay_mask = 1.0 decay_mask = 1.0
...@@ -146,26 +136,17 @@ def perturb_past( ...@@ -146,26 +136,17 @@ def perturb_past(
_, _, _, curr_length, _ = past[0].shape _, _, _, curr_length, _ = past[0].shape
if curr_length > window_length and window_length > 0: if curr_length > window_length and window_length > 0:
ones_key_val_shape = ( ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
tuple(past[0].shape[:-2])
+ tuple([window_length])
+ tuple(past[0].shape[-1:])
)
zeros_key_val_shape = ( zeros_key_val_shape = (
tuple(past[0].shape[:-2]) tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
+ tuple([curr_length - window_length])
+ tuple(past[0].shape[-1:])
) )
ones_mask = torch.ones(ones_key_val_shape) ones_mask = torch.ones(ones_key_val_shape)
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3) ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
ones_mask = ones_mask.permute(0, 1, 2, 4, 3) ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
window_mask = torch.cat( window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
(ones_mask, torch.zeros(zeros_key_val_shape)),
dim=-2
).to(device)
else: else:
window_mask = torch.ones_like(past[0]).to(device) window_mask = torch.ones_like(past[0]).to(device)
...@@ -175,8 +156,7 @@ def perturb_past( ...@@ -175,8 +156,7 @@ def perturb_past(
for i in range(num_iterations): for i in range(num_iterations):
print("Iteration ", i + 1) print("Iteration ", i + 1)
curr_perturbation = [ curr_perturbation = [
to_var(torch.from_numpy(p_), requires_grad=True, device=device) to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
for p_ in grad_accumulator
] ]
# Compute hidden using perturbed past # Compute hidden using perturbed past
...@@ -184,10 +164,7 @@ def perturb_past( ...@@ -184,10 +164,7 @@ def perturb_past(
_, _, _, curr_length, _ = curr_perturbation[0].shape _, _, _, curr_length, _ = curr_perturbation[0].shape
all_logits, _, all_hidden = model(last, past=perturbed_past) all_logits, _, all_hidden = model(last, past=perturbed_past)
hidden = all_hidden[-1] hidden = all_hidden[-1]
new_accumulated_hidden = accumulated_hidden + torch.sum( new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
hidden,
dim=1
).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth) # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :] logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1) probs = F.softmax(logits, dim=-1)
...@@ -210,20 +187,13 @@ def perturb_past( ...@@ -210,20 +187,13 @@ def perturb_past(
wte = model.resize_token_embeddings() wte = model.resize_token_embeddings()
for _ in range(horizon_length): for _ in range(horizon_length):
inputs_embeds = torch.matmul(curr_probs, wte.weight.data) inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
_, curr_unpert_past, curr_all_hidden = model( _, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
past=curr_unpert_past,
inputs_embeds=inputs_embeds
)
curr_hidden = curr_all_hidden[-1] curr_hidden = curr_all_hidden[-1]
new_accumulated_hidden = new_accumulated_hidden + torch.sum( new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
curr_hidden, dim=1)
prediction = classifier(new_accumulated_hidden / prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))
(curr_length + 1 + horizon_length))
label = torch.tensor(prediction.shape[0] * [class_label], label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
device=device,
dtype=torch.long)
discrim_loss = ce_loss(prediction, label) discrim_loss = ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy()) print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
loss += discrim_loss loss += discrim_loss
...@@ -232,21 +202,15 @@ def perturb_past( ...@@ -232,21 +202,15 @@ def perturb_past(
kl_loss = 0.0 kl_loss = 0.0
if kl_scale > 0.0: if kl_scale > 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = ( unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
unpert_probs + SMALL_CONST * correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
(unpert_probs <= SMALL_CONST).float().to(device).detach()
)
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
device).detach()
corrected_probs = probs + correction.detach() corrected_probs = probs + correction.detach()
kl_loss = kl_scale * ( kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
(corrected_probs * (corrected_probs / unpert_probs).log()).sum() print(" kl_loss", kl_loss.data.cpu().numpy())
)
print(' kl_loss', kl_loss.data.cpu().numpy())
loss += kl_loss loss += kl_loss
loss_per_iter.append(loss.data.cpu().numpy()) loss_per_iter.append(loss.data.cpu().numpy())
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy()) print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())
# compute gradients # compute gradients
loss.backward() loss.backward()
...@@ -259,15 +223,12 @@ def perturb_past( ...@@ -259,15 +223,12 @@ def perturb_past(
] ]
else: else:
grad_norms = [ grad_norms = [
(torch.norm(p_.grad * window_mask) + SMALL_CONST) (torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
for index, p_ in enumerate(curr_perturbation)
] ]
# normalize gradients # normalize gradients
grad = [ grad = [
-stepsize * -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
(p_.grad * window_mask / grad_norms[
index] ** gamma).data.cpu().numpy()
for index, p_ in enumerate(curr_perturbation) for index, p_ in enumerate(curr_perturbation)
] ]
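The gradient step above scales each perturbation by the inverse of its own (masked) gradient norm raised to gamma before applying the step size. A toy version of that update rule, detached from the script; the tensor, loss, and constants below are illustrative assumptions only.

    import torch

    stepsize, gamma, eps = 0.02, 1.5, 1e-15

    p = torch.randn(2, 3, requires_grad=True)
    loss = (p ** 2).sum()
    loss.backward()

    grad_norm = torch.norm(p.grad) + eps
    step = -stepsize * (p.grad / grad_norm ** gamma)  # norm-scaled descent step
    perturbed = (p + step).detach()                   # accumulated as the new perturbation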
...@@ -285,36 +246,27 @@ def perturb_past( ...@@ -285,36 +246,27 @@ def perturb_past(
past = new_past past = new_past
# apply the accumulated perturbations to the past # apply the accumulated perturbations to the past
grad_accumulator = [ grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
for p_ in grad_accumulator
]
pert_past = list(map(add, past, grad_accumulator)) pert_past = list(map(add, past, grad_accumulator))
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
def get_classifier( def get_classifier(
name: Optional[str], class_label: Union[str, int], name: Optional[str], class_label: Union[str, int], device: str
device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]: ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
if name is None: if name is None:
return None, None return None, None
params = DISCRIMINATOR_MODELS_PARAMS[name] params = DISCRIMINATOR_MODELS_PARAMS[name]
classifier = ClassificationHead( classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
class_size=params['class_size'],
embed_size=params['embed_size']
).to(device)
if "url" in params: if "url" in params:
resolved_archive_file = cached_path(params["url"]) resolved_archive_file = cached_path(params["url"])
elif "path" in params: elif "path" in params:
resolved_archive_file = params["path"] resolved_archive_file = params["path"]
else: else:
raise ValueError("Either url or path have to be specified " raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
"in the discriminator model parameters") classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
classifier.load_state_dict(
torch.load(resolved_archive_file, map_location=device))
classifier.eval() classifier.eval()
if isinstance(class_label, str): if isinstance(class_label, str):
...@@ -341,8 +293,7 @@ def get_classifier( ...@@ -341,8 +293,7 @@ def get_classifier(
return classifier, label_id return classifier, label_id
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
List[List[List[int]]]:
bow_indices = [] bow_indices = []
for id_or_path in bag_of_words_ids_or_paths: for id_or_path in bag_of_words_ids_or_paths:
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP: if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
...@@ -351,13 +302,11 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> ...@@ -351,13 +302,11 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) ->
filepath = id_or_path filepath = id_or_path
with open(filepath, "r") as f: with open(filepath, "r") as f:
words = f.read().strip().split("\n") words = f.read().strip().split("\n")
bow_indices.append( bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
words])
return bow_indices return bow_indices
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'): def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
if bow_indices is None: if bow_indices is None:
return None return None
...@@ -396,16 +345,11 @@ def full_text_generation( ...@@ -396,16 +345,11 @@ def full_text_generation(
kl_scale=0.01, kl_scale=0.01,
**kwargs **kwargs
): ):
classifier, class_id = get_classifier( classifier, class_id = get_classifier(discrim, class_label, device)
discrim,
class_label,
device
)
bow_indices = [] bow_indices = []
if bag_of_words: if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
tokenizer)
if bag_of_words and classifier: if bag_of_words and classifier:
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.") print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
...@@ -423,15 +367,9 @@ def full_text_generation( ...@@ -423,15 +367,9 @@ def full_text_generation(
raise Exception("Specify either a bag of words or a discriminator") raise Exception("Specify either a bag of words or a discriminator")
unpert_gen_tok_text, _, _ = generate_text_pplm( unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model, model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
tokenizer=tokenizer,
context=context,
device=device,
length=length,
sample=sample,
perturb=False
) )
if device == 'cuda': if device == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
pert_gen_tok_texts = [] pert_gen_tok_texts = []
...@@ -468,7 +406,7 @@ def full_text_generation( ...@@ -468,7 +406,7 @@ def full_text_generation(
discrim_losses.append(discrim_loss.data.cpu().numpy()) discrim_losses.append(discrim_loss.data.cpu().numpy())
losses_in_time.append(loss_in_time) losses_in_time.append(loss_in_time)
if device == 'cuda': if device == "cuda":
torch.cuda.empty_cache() torch.cuda.empty_cache()
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
...@@ -507,8 +445,7 @@ def generate_text_pplm( ...@@ -507,8 +445,7 @@ def generate_text_pplm(
output_so_far = context_t output_so_far = context_t
# collect one hot vectors for bags of words # collect one hot vectors for bags of words
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)
device)
grad_norms = None grad_norms = None
last = None last = None
...@@ -575,13 +512,9 @@ def generate_text_pplm( ...@@ -575,13 +512,9 @@ def generate_text_pplm(
if classifier is not None: if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss() ce_loss = torch.nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([class_label], device=device, label = torch.tensor([class_label], device=device, dtype=torch.long)
dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label) unpert_discrim_loss = ce_loss(prediction, label)
print( print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
"unperturbed discrim loss",
unpert_discrim_loss.data.cpu().numpy()
)
else: else:
unpert_discrim_loss = 0 unpert_discrim_loss = 0
...@@ -590,10 +523,8 @@ def generate_text_pplm( ...@@ -590,10 +523,8 @@ def generate_text_pplm(
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = ((pert_probs ** gm_scale) * ( pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
pert_probs = top_k_filter(pert_probs, k=top_k,
probs=True) # + SMALL_CONST
# rescale # rescale
if torch.sum(pert_probs) <= 1: if torch.sum(pert_probs) <= 1:
...@@ -611,10 +542,7 @@ def generate_text_pplm( ...@@ -611,10 +542,7 @@ def generate_text_pplm(
_, last = torch.topk(pert_probs, k=1, dim=-1) _, last = torch.topk(pert_probs, k=1, dim=-1)
# update context/output_so_far appending the new token # update context/output_so_far appending the new token
output_so_far = ( output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)
last if output_so_far is None
else torch.cat((output_so_far, last), dim=1)
)
print(tokenizer.decode(output_so_far.tolist()[0])) print(tokenizer.decode(output_so_far.tolist()[0]))
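A compact sketch of the fusion performed in the sampling step above: the perturbed and unperturbed next-token distributions are combined as a weighted geometric mean controlled by gm_scale, renormalized, and then sampled. The probability values below are made up for illustration.

    import torch

    gm_scale = 0.9
    pert_probs = torch.tensor([[0.70, 0.20, 0.10]])
    unpert_probs = torch.tensor([[0.30, 0.50, 0.20]])

    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
    fused = fused / fused.sum(dim=-1, keepdim=True)        # rescale to a valid distribution
    next_token = torch.multinomial(fused, num_samples=1)   # or torch.topk(fused, k=1) for greedy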
...@@ -623,16 +551,14 @@ def generate_text_pplm( ...@@ -623,16 +551,14 @@ def generate_text_pplm(
def set_generic_model_params(discrim_weights, discrim_meta): def set_generic_model_params(discrim_weights, discrim_meta):
if discrim_weights is None: if discrim_weights is None:
raise ValueError('When using a generic discriminator, ' raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
'discrim_weights need to be specified')
if discrim_meta is None: if discrim_meta is None:
raise ValueError('When using a generic discriminator, ' raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
'discrim_meta need to be specified')
with open(discrim_meta, 'r') as discrim_meta_file: with open(discrim_meta, "r") as discrim_meta_file:
meta = json.load(discrim_meta_file) meta = json.load(discrim_meta_file)
meta['path'] = discrim_weights meta["path"] = discrim_weights
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
def run_pplm_example( def run_pplm_example(
...@@ -660,7 +586,7 @@ def run_pplm_example( ...@@ -660,7 +586,7 @@ def run_pplm_example(
kl_scale=0.01, kl_scale=0.01,
seed=0, seed=0,
no_cuda=False, no_cuda=False,
colorama=False colorama=False,
): ):
# set Random seed # set Random seed
torch.manual_seed(seed) torch.manual_seed(seed)
...@@ -669,21 +595,15 @@ def run_pplm_example( ...@@ -669,21 +595,15 @@ def run_pplm_example(
# set the device # set the device
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
if discrim == 'generic': if discrim == "generic":
set_generic_model_params(discrim_weights, discrim_meta) set_generic_model_params(discrim_weights, discrim_meta)
if discrim is not None: if discrim is not None:
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][ pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
"pretrained_model" print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
]
print("discrim = {}, pretrained_model set "
"to discriminator's = {}".format(discrim, pretrained_model))
# load pretrained model # load pretrained model
model = GPT2LMHeadModel.from_pretrained( model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
pretrained_model,
output_hidden_states=True
)
model.to(device) model.to(device)
model.eval() model.eval()
...@@ -696,9 +616,7 @@ def run_pplm_example( ...@@ -696,9 +616,7 @@ def run_pplm_example(
# figure out conditioning text # figure out conditioning text
if uncond: if uncond:
tokenized_cond_text = tokenizer.encode( tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
[tokenizer.bos_token]
)
else: else:
raw_text = cond_text raw_text = cond_text
while not raw_text: while not raw_text:
...@@ -750,8 +668,7 @@ def run_pplm_example( ...@@ -750,8 +668,7 @@ def run_pplm_example(
bow_word_ids = set() bow_word_ids = set()
if bag_of_words and colorama: if bag_of_words and colorama:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
tokenizer)
for single_bow_list in bow_indices: for single_bow_list in bow_indices:
# filtering all words in the list composed of more than 1 token # filtering all words in the list composed of more than 1 token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list)) filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
...@@ -765,13 +682,11 @@ def run_pplm_example( ...@@ -765,13 +682,11 @@ def run_pplm_example(
if colorama: if colorama:
import colorama import colorama
pert_gen_text = '' pert_gen_text = ""
for word_id in pert_gen_tok_text.tolist()[0]: for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_word_ids: if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format( pert_gen_text += "{}{}{}".format(
colorama.Fore.RED, colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
tokenizer.decode([word_id]),
colorama.Style.RESET_ALL
) )
else: else:
pert_gen_text += tokenizer.decode([word_id]) pert_gen_text += tokenizer.decode([word_id])
...@@ -785,14 +700,12 @@ def run_pplm_example( ...@@ -785,14 +700,12 @@ def run_pplm_example(
pass pass
# keep the prefix, perturbed seq, original seq for each index # keep the prefix, perturbed seq, original seq for each index
generated_texts.append( generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
)
return return
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--pretrained_model", "--pretrained_model",
...@@ -801,19 +714,10 @@ if __name__ == '__main__': ...@@ -801,19 +714,10 @@ if __name__ == '__main__':
default="gpt2-medium", default="gpt2-medium",
help="pretrained model name or path to local checkpoint", help="pretrained model name or path to local checkpoint",
) )
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
parser.add_argument( parser.add_argument(
"--cond_text", type=str, default="The lake", "--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
help="Prefix texts to condition on"
)
parser.add_argument(
"--uncond", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of samples to generate from the modified latents",
) )
parser.add_argument( parser.add_argument(
"--bag_of_words", "--bag_of_words",
...@@ -832,48 +736,36 @@ if __name__ == '__main__': ...@@ -832,48 +736,36 @@ if __name__ == '__main__':
choices=("clickbait", "sentiment", "toxicity", "generic"), choices=("clickbait", "sentiment", "toxicity", "generic"),
help="Discriminator to use", help="Discriminator to use",
) )
parser.add_argument('--discrim_weights', type=str, default=None, parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
help='Weights for the generic discriminator')
parser.add_argument('--discrim_meta', type=str, default=None,
help='Meta information for the generic discriminator')
parser.add_argument( parser.add_argument(
"--class_label", "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
type=int, )
default=-1, parser.add_argument(
help="Class label used for the discriminator", "--class_label", type=int, default=-1, help="Class label used for the discriminator",
) )
parser.add_argument("--length", type=int, default=100) parser.add_argument("--length", type=int, default=100)
parser.add_argument("--stepsize", type=float, default=0.02) parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=10) parser.add_argument("--top_k", type=int, default=10)
parser.add_argument( parser.add_argument("--sample", action="store_true", help="Generate from end-of-text as prefix")
"--sample", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument("--num_iterations", type=int, default=3) parser.add_argument("--num_iterations", type=int, default=3)
parser.add_argument("--grad_length", type=int, default=10000) parser.add_argument("--grad_length", type=int, default=10000)
parser.add_argument( parser.add_argument(
"--window_length", "--window_length",
type=int, type=int,
default=0, default=0,
help="Length of past which is being optimized; " help="Length of past which is being optimized; " "0 corresponds to infinite window length",
"0 corresponds to infinite window length",
) )
parser.add_argument( parser.add_argument(
"--horizon_length", "--horizon_length", type=int, default=1, help="Length of future to optimize over",
type=int,
default=1,
help="Length of future to optimize over",
) )
parser.add_argument("--decay", action="store_true", parser.add_argument("--decay", action="store_true", help="whether to decay or not")
help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5) parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--gm_scale", type=float, default=0.9) parser.add_argument("--gm_scale", type=float, default=0.9)
parser.add_argument("--kl_scale", type=float, default=0.01) parser.add_argument("--kl_scale", type=float, default=0.01)
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--no_cuda", action="store_true", help="no cuda") parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument("--colorama", action="store_true", parser.add_argument("--colorama", action="store_true", help="colors keywords")
help="colors keywords")
args = parser.parse_args() args = parser.parse_args()
run_pplm_example(**vars(args)) run_pplm_example(**vars(args))
#! /usr/bin/env python3 #! /usr/bin/env python3
# coding=utf-8 # coding=utf-8
#Copyright (c) 2019 Uber Technologies, Inc. # Copyright (c) 2019 Uber Technologies, Inc.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
#http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
import argparse import argparse
import csv import csv
...@@ -42,26 +42,15 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha ...@@ -42,26 +42,15 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
max_length_seq = 100 max_length_seq = 100
class Discriminator(torch.nn.Module): class Discriminator(torch.nn.Module):
"""Transformer encoder followed by a Classification Head""" """Transformer encoder followed by a Classification Head"""
def __init__( def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
self,
class_size,
pretrained_model="gpt2-medium",
cached_mode=False,
device='cpu'
):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model) self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model) self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
self.embed_size = self.encoder.transformer.config.hidden_size self.embed_size = self.encoder.transformer.config.hidden_size
self.classifier_head = ClassificationHead( self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
class_size=class_size,
embed_size=self.embed_size
)
self.cached_mode = cached_mode self.cached_mode = cached_mode
self.device = device self.device = device
...@@ -74,14 +63,10 @@ class Discriminator(torch.nn.Module): ...@@ -74,14 +63,10 @@ class Discriminator(torch.nn.Module):
self.classifier_head.train() self.classifier_head.train()
def avg_representation(self, x): def avg_representation(self, x):
mask = x.ne(0).unsqueeze(2).repeat( mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
1, 1, self.embed_size
).float().to(self.device).detach()
hidden, _ = self.encoder.transformer(x) hidden, _ = self.encoder.transformer(x)
masked_hidden = hidden * mask masked_hidden = hidden * mask
avg_hidden = torch.sum(masked_hidden, dim=1) / ( avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
torch.sum(mask, dim=1).detach() + EPSILON
)
return avg_hidden return avg_hidden
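avg_representation above is a masked mean over the sequence dimension, ignoring padding token id 0. A minimal standalone equivalent; the tensor shapes, token ids, and epsilon constant are illustrative assumptions, not values from this file.

    import torch

    EPSILON = 1e-10
    hidden = torch.randn(2, 5, 8)  # (batch, seq_len, embed_size)
    token_ids = torch.tensor([[5, 8, 2, 0, 0], [7, 3, 9, 4, 1]])

    mask = token_ids.ne(0).unsqueeze(2).float()  # 1.0 for real tokens, 0.0 for padding
    avg_hidden = (hidden * mask).sum(dim=1) / (mask.sum(dim=1) + EPSILON)  # shape (2, 8)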
def forward(self, x): def forward(self, x):
...@@ -117,10 +102,7 @@ def collate_fn(data): ...@@ -117,10 +102,7 @@ def collate_fn(data):
def pad_sequences(sequences): def pad_sequences(sequences):
lengths = [len(seq) for seq in sequences] lengths = [len(seq) for seq in sequences]
padded_sequences = torch.zeros( padded_sequences = torch.zeros(len(sequences), max(lengths)).long() # padding value = 0
len(sequences),
max(lengths)
).long() # padding value = 0
for i, seq in enumerate(sequences): for i, seq in enumerate(sequences):
end = lengths[i] end = lengths[i]
...@@ -149,8 +131,7 @@ def cached_collate_fn(data): ...@@ -149,8 +131,7 @@ def cached_collate_fn(data):
return x_batch, y_batch return x_batch, y_batch
def train_epoch(data_loader, discriminator, optimizer, def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
epoch=0, log_interval=10, device='cpu'):
samples_so_far = 0 samples_so_far = 0
discriminator.train_custom() discriminator.train_custom()
for batch_idx, (input_t, target_t) in enumerate(data_loader): for batch_idx, (input_t, target_t) in enumerate(data_loader):
...@@ -169,13 +150,15 @@ def train_epoch(data_loader, discriminator, optimizer, ...@@ -169,13 +150,15 @@ def train_epoch(data_loader, discriminator, optimizer,
print( print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch + 1, epoch + 1,
samples_so_far, len(data_loader.dataset), samples_so_far,
100 * samples_so_far / len(data_loader.dataset), loss.item() len(data_loader.dataset),
100 * samples_so_far / len(data_loader.dataset),
loss.item(),
) )
) )
def evaluate_performance(data_loader, discriminator, device='cpu'): def evaluate_performance(data_loader, discriminator, device="cpu"):
discriminator.eval() discriminator.eval()
test_loss = 0 test_loss = 0
correct = 0 correct = 0
...@@ -194,13 +177,12 @@ def evaluate_performance(data_loader, discriminator, device='cpu'): ...@@ -194,13 +177,12 @@ def evaluate_performance(data_loader, discriminator, device='cpu'):
print( print(
"Performance on test set: " "Performance on test set: "
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format( "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
test_loss, correct, len(data_loader.dataset), test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
100. * correct / len(data_loader.dataset)
) )
) )
def predict(input_sentence, model, classes, cached=False, device='cpu'): def predict(input_sentence, model, classes, cached=False, device="cpu"):
input_t = model.tokenizer.encode(input_sentence) input_t = model.tokenizer.encode(input_sentence)
input_t = torch.tensor([input_t], dtype=torch.long, device=device) input_t = torch.tensor([input_t], dtype=torch.long, device=device)
if cached: if cached:
...@@ -208,17 +190,14 @@ def predict(input_sentence, model, classes, cached=False, device='cpu'): ...@@ -208,17 +190,14 @@ def predict(input_sentence, model, classes, cached=False, device='cpu'):
log_probs = model(input_t).data.cpu().numpy().flatten().tolist() log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
print("Input sentence:", input_sentence) print("Input sentence:", input_sentence)
print("Predictions:", ", ".join( print(
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in "Predictions:",
zip(classes, log_probs) ", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
)) )
def get_cached_data_loader(dataset, batch_size, discriminator, def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
shuffle=False, device='cpu'): data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
collate_fn=collate_fn)
xs = [] xs = []
ys = [] ys = []
...@@ -231,50 +210,44 @@ def get_cached_data_loader(dataset, batch_size, discriminator, ...@@ -231,50 +210,44 @@ def get_cached_data_loader(dataset, batch_size, discriminator,
ys += y.cpu().numpy().tolist() ys += y.cpu().numpy().tolist()
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
dataset=Dataset(xs, ys), dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
batch_size=batch_size, )
shuffle=shuffle,
collate_fn=cached_collate_fn)
return data_loader return data_loader
def train_discriminator( def train_discriminator(
dataset, dataset_fp=None, pretrained_model="gpt2-medium", dataset,
epochs=10, batch_size=64, log_interval=10, dataset_fp=None,
save_model=False, cached=False, no_cuda=False): pretrained_model="gpt2-medium",
epochs=10,
batch_size=64,
log_interval=10,
save_model=False,
cached=False,
no_cuda=False,
):
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
print("Preprocessing {} dataset...".format(dataset)) print("Preprocessing {} dataset...".format(dataset))
start = time.time() start = time.time()
if dataset == "SST": if dataset == "SST":
idx2class = ["positive", "negative", "very positive", "very negative", idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
"neutral"]
class2idx = {c: i for i, c in enumerate(idx2class)} class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device) ).to(device)
text = torchtext_data.Field() text = torchtext_data.Field()
label = torchtext_data.Field(sequential=False) label = torchtext_data.Field(sequential=False)
train_data, val_data, test_data = datasets.SST.splits( train_data, val_data, test_data = datasets.SST.splits(text, label, fine_grained=True, train_subtrees=True,)
text,
label,
fine_grained=True,
train_subtrees=True,
)
x = [] x = []
y = [] y = []
for i in trange(len(train_data), ascii=True): for i in trange(len(train_data), ascii=True):
seq = TreebankWordDetokenizer().detokenize( seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
vars(train_data[i])["text"]
)
seq = discriminator.tokenizer.encode(seq) seq = discriminator.tokenizer.encode(seq)
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long) seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
x.append(seq) x.append(seq)
...@@ -284,9 +257,7 @@ def train_discriminator( ...@@ -284,9 +257,7 @@ def train_discriminator(
test_x = [] test_x = []
test_y = [] test_y = []
for i in trange(len(test_data), ascii=True): for i in trange(len(test_data), ascii=True):
seq = TreebankWordDetokenizer().detokenize( seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
vars(test_data[i])["text"]
)
seq = discriminator.tokenizer.encode(seq) seq = discriminator.tokenizer.encode(seq)
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long) seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
test_x.append(seq) test_x.append(seq)
...@@ -306,10 +277,7 @@ def train_discriminator( ...@@ -306,10 +277,7 @@ def train_discriminator(
class2idx = {c: i for i, c in enumerate(idx2class)} class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device) ).to(device)
with open("datasets/clickbait/clickbait_train_prefix.txt") as f: with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
...@@ -318,9 +286,7 @@ def train_discriminator( ...@@ -318,9 +286,7 @@ def train_discriminator(
try: try:
data.append(eval(line)) data.append(eval(line))
except: except:
print("Error evaluating line {}: {}".format( print("Error evaluating line {}: {}".format(i, line))
i, line
))
continue continue
x = [] x = []
y = [] y = []
...@@ -331,27 +297,20 @@ def train_discriminator( ...@@ -331,27 +297,20 @@ def train_discriminator(
seq = discriminator.tokenizer.encode(d["text"]) seq = discriminator.tokenizer.encode(d["text"])
if len(seq) < max_length_seq: if len(seq) < max_length_seq:
seq = torch.tensor( seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
[50256] + seq, device=device, dtype=torch.long
)
else: else:
print("Line {} is longer than maximum length {}".format( print("Line {} is longer than maximum length {}".format(i, max_length_seq))
i, max_length_seq
))
continue continue
x.append(seq) x.append(seq)
y.append(d["label"]) y.append(d["label"])
except: except:
print("Error evaluating / tokenizing" print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
" line {}, skipping it".format(i))
pass pass
full_dataset = Dataset(x, y) full_dataset = Dataset(x, y)
train_size = int(0.9 * len(full_dataset)) train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split( train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
full_dataset, [train_size, test_size]
)
discriminator_meta = { discriminator_meta = {
"class_size": len(idx2class), "class_size": len(idx2class),
...@@ -366,10 +325,7 @@ def train_discriminator( ...@@ -366,10 +325,7 @@ def train_discriminator(
class2idx = {c: i for i, c in enumerate(idx2class)} class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device) ).to(device)
x = [] x = []
...@@ -381,27 +337,20 @@ def train_discriminator(
                seq = discriminator.tokenizer.encode(d["text"])
                if len(seq) < max_length_seq:
                    seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                else:
                    print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                    continue
                x.append(seq)
                y.append(int(np.sum(d["label"]) > 0))
            except:
                print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
                pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
...@@ -416,8 +365,7 @@ def train_discriminator(
        # class \t text
        if dataset_fp is None:
            raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified as well.")

        classes = set()
        with open(dataset_fp) as f:
...@@ -430,10 +378,7 @@ def train_discriminator(
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        x = []
...@@ -447,18 +392,11 @@ def train_discriminator(
            try:
                seq = discriminator.tokenizer.encode(text)
                if len(seq) < max_length_seq:
                    seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                else:
                    print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                    continue
                x.append(seq)
...@@ -471,10 +409,7 @@ def train_discriminator(
        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
...@@ -485,9 +420,7 @@ def train_discriminator(
        }

    end = time.time()
    print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
...@@ -495,30 +428,21 @@ def train_discriminator(
        start = time.time()

        train_loader = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)
        test_loader = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))

    else:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
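    # NOTE (illustrative sketch, not part of this commit): the collate_fn handed to the
    # uncached DataLoaders above is defined earlier in this script and is not shown in
    # this hunk. A minimal stand-in that pads variable-length token sequences to the
    # batch maximum could look like the following (the name pad_collate and the padding
    # value 0 are assumptions, not the project's code):
    #
    #     from torch.nn.utils.rnn import pad_sequence
    #
    #     def pad_collate(batch):
    #         # batch: list of (token_id_tensor, label) pairs, as built by Dataset(x, y)
    #         seqs, labels = zip(*batch)
    #         padded = pad_sequence(seqs, batch_first=True, padding_value=0)
    #         return padded, torch.tensor(labels, dtype=torch.long)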
...@@ -533,56 +457,61 @@ def train_discriminator(
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval,
            device=device,
        )
        evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class, cached=cached, device=device)

        if save_model:
            # torch.save(discriminator.state_dict(),
            #            "{}_discriminator_{}.pt".format(
            #                args.dataset, epoch + 1
            #            ))
            torch.save(
                discriminator.get_classifier().state_dict(),
                "{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
            )
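# NOTE (illustrative sketch, not part of this commit): the metadata JSON and per-epoch
# classifier-head weights saved above can be reloaded with plain json/torch calls; with
# the default arguments (dataset="SST", epochs=10) the last files written would be
# SST_classifier_head_meta.json and SST_classifier_head_epoch_10.pt. The classifier-head
# class itself is not shown in this diff, so the constructor below is only a hypothetical
# stand-in:
#
#     import json
#     import torch
#
#     with open("SST_classifier_head_meta.json") as f:
#         meta = json.load(f)
#     state_dict = torch.load("SST_classifier_head_epoch_10.pt", map_location="cpu")
#     # head = ClassificationHead(class_size=meta["class_size"], ...)  # hypothetical class
#     # head.load_state_dict(state_dict)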
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a discriminator on top of GPT-2 representations")
    parser.add_argument(
        "--dataset",
        type=str,
        default="SST",
        choices=("SST", "clickbait", "toxic", "generic"),
        help="dataset to train the discriminator on. "
        "In case of generic, the dataset is expected "
        "to be a TSV file with structure: class \\t text",
    )
    parser.add_argument(
        "--dataset_fp",
        type=str,
        default="",
        help="File path of the dataset to use. " "Needed only in case of generic dataset",
    )
    parser.add_argument(
        "--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
    )
    parser.add_argument("--epochs", type=int, default=10, metavar="N", help="Number of training epochs")
    parser.add_argument(
        "--batch_size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--save_model", action="store_true", help="whether to save the model")
    parser.add_argument("--cached", action="store_true", help="whether to cache the input representations")
    parser.add_argument("--no_cuda", action="store_true", help="use to turn off cuda")
    args = parser.parse_args()

    train_discriminator(**(vars(args)))
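A minimal usage sketch (not part of this commit; the module name run_pplm_discrim_train is an assumption, since the file path is not shown in this diff): calling train_discriminator directly with keyword arguments mirrors running the script from the command line, because the entry point above simply forwards vars(args).

    # Hypothetical invocation; parameter names come from the argparse flags defined above.
    from run_pplm_discrim_train import train_discriminator

    train_discriminator(
        dataset="SST",                    # one of: SST, clickbait, toxic, generic
        dataset_fp="",                    # only needed for the generic dataset
        pretrained_model="gpt2-medium",
        epochs=10,
        batch_size=64,
        log_interval=10,
        save_model=True,                  # write <dataset>_classifier_head_* artifacts
        cached=True,                      # precompute GPT-2 representations before training
        no_cuda=False,
    )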