Commit 976554a4 authored by Matthew Carrigan's avatar Matthew Carrigan
Browse files

First commit of the new LM finetuning

parent f3e54048
from argparse import ArgumentParser
from pathlib import Path
import torch
import logging
import json
import random
import numpy as np
from collections import namedtuple
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
logger = logging.getLogger(__name__)
def convert_example_to_features(example, tokenizer, max_seq_length):
tokens = example["tokens"]
segment_ids = example["segment_ids"]
is_random_next = example["is_random_next"]
masked_lm_positions = example["masked_lm_positions"]
masked_lm_labels = example["masked_lm_labels"]
assert len(tokens) == len(segment_ids) <= max_seq_length # The preprocessed data should be already truncated
input_ids = tokenizer.convert_tokens_to_ids(tokens)
masked_label_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels)
input_array = np.zeros(max_seq_length, dtype=np.int)
input_array[:len(input_ids)] = input_ids
mask_array = np.zeros(max_seq_length, dtype=np.bool)
mask_array[:len(input_ids)] = 1
segment_array = np.zeros(max_seq_length, dtype=np.bool)
segment_array[:len(segment_ids)] = segment_ids
lm_label_array = np.full(max_seq_length, dtype=np.int, fill_value=-1)
lm_label_array[masked_lm_positions] = masked_label_ids
features = InputFeatures(input_ids=input_array,
input_mask=mask_array,
segment_ids=segment_array,
lm_label_ids=lm_label_array,
is_next=is_random_next)
return features
class PregeneratedDataset(Dataset):
def __init__(self, training_path, epoch, tokenizer, num_data_epochs):
# TODO Add an option to memmap the training data
self.vocab = tokenizer.vocab
self.tokenizer = tokenizer
self.epoch = epoch
self.data_epoch = epoch % num_data_epochs
data_file = training_path / f"epoch_{self.data_epoch}.json"
metrics_file = training_path / f"epoch_{self.data_epoch}_metrics.json"
assert data_file.is_file() and metrics_file.is_file()
metrics = json.loads(metrics_file.read_text())
num_samples = metrics['num_training_examples']
seq_len = metrics['max_seq_len']
input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
logger.info(f"Loading training examples for epoch {epoch}")
with data_file.open() as f:
for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
example = json.loads(line.rstrip())
features = convert_example_to_features(example, tokenizer, seq_len)
input_ids[i] = features.input_ids
segment_ids[i] = features.segment_ids
input_masks[i] = features.input_mask
lm_label_ids[i] = features.lm_label_ids
is_nexts[i] = features.is_next
assert i == num_samples - 1 # Assert that the sample count metric was true
logger.info("Loading complete!")
self.num_samples = num_samples
self.seq_len = seq_len
self.input_ids = input_ids
self.input_masks = input_masks
self.segment_ids = segment_ids
self.lm_label_ids = lm_label_ids
self.is_nexts = is_nexts
def __len__(self):
return self.num_samples
def __getitem__(self, item):
return (torch.tensor(self.input_ids[item].astype(np.int64)),
torch.tensor(self.input_masks[item].astype(np.int64)),
torch.tensor(self.segment_ids[item].astype(np.int64)),
torch.tensor(self.lm_label_ids[item].astype(np.int64)),
torch.tensor(self.is_nexts[item].astype(np.int64)))
# TODO 2: Test it's all working
# TODO 3: Add a README (can you do that with subfolders?)
def main():
parser = ArgumentParser()
parser.add_argument('--pregenerated_data', type=Path, required=True)
parser.add_argument('--output_dir', type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type = float, default = 0,
help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--learning_rate",
default=3e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
args = parser.parse_args()
assert args.pregenerated_data.is_dir(), "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
samples_per_epoch = []
for i in range(args.epochs):
epoch_file = args.pregenerated_data / f"epoch_{i}.json"
metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
if epoch_file.is_file() and metrics_file.is_file():
metrics = json.loads(metrics_file.read_text())
samples_per_epoch.append(metrics['num_training_examples'])
else:
if i == 0:
exit("No training data was found!")
print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
print("This script will loop over the available data, but training diversity may be negatively impacted.")
num_data_epochs = i
break
else:
num_data_epochs = args.epochs
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
args.output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
total_train_examples = 0
for i in range(args.epochs):
# The modulo takes into account the fact that we may loop over limited epochs of data
total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]
num_train_optimization_steps = int(
total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model
model = BertForPreTraining.from_pretrained(args.bert_model)
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
global_step = 0
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_examples}")
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
model.train()
for epoch in range(args.epochs):
epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
num_data_epochs=num_data_epochs)
if args.local_rank == -1:
train_sampler = RandomSampler(epoch_dataset)
else:
train_sampler = DistributedSampler(epoch_dataset)
train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
for step, batch in enumerate(train_dataloader):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
pbar.update(1)
mean_loss = tr_loss / nb_tr_steps
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
# Save a trained model
logger.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = args.output_dir / "pytorch_model.bin"
torch.save(model_to_save.state_dict(), str(output_model_file))
if __name__ == '__main__':
main()
# Step 1: Slurp the dataset up, tokenize each sentence, and store as docs -> sentences -> tokens
# Step 2: Walk over the dataset, using the Google BERT logic to concatenate sentences into training examples
# Step 3: Write out the examples, possibly as Torch tensors?
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm, trange
from random import random, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import json
class DocumentDatabase:
def __init__(self, document_list):
self.document_list = document_list
self.doc_starts = {}
self.weighted_doc_samples = []
i = 0
for doc_idx, doc in enumerate(document_list):
self.doc_starts[doc_idx] = i
self.weighted_doc_samples.extend([doc_idx] * len(doc))
i += len(doc)
def sample_doc(self, current_idx, sentence_weighted=True):
# Uses the current iteration counter to ensure we don't sample the same doc twice
if sentence_weighted:
num_sentences = len(self.document_list[current_idx])
# This very painful line randomly selects a document, weighted by the number of sentences they contain,
# while guaranteeing that it won't return the original document
sampled_val = (
(self.doc_starts[current_idx] + num_sentences
+ randint(0, len(self.weighted_doc_samples) - num_sentences - 1))
% len(self.weighted_doc_samples))
sampled_doc_index = self.weighted_doc_samples[sampled_val]
else:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index = current_idx + randint(1, len(self.document_list)-1)
assert sampled_doc_index != current_idx
return self.document_list[sampled_doc_index]
def __len__(self):
return len(self.document_list)
def __getitem__(self, item):
return self.document_list[item]
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
"""Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
"""Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
with several refactors to clean it up and remove a lot of unnecessary variables."""
cand_indices = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indices.append(i)
num_to_mask = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
shuffle(cand_indices)
mask_indices = sorted(sample(cand_indices, num_to_mask))
masked_token_labels = []
for index in mask_indices:
# 80% of the time, replace with [MASK]
if random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = choice(vocab_list)
masked_token_labels.append(tokens[index])
# Once we've saved the true label for that token, we can overwrite it with the masked version
tokens[index] = masked_token
return tokens, mask_indices, masked_token_labels
def create_instances_from_document(
doc_database, doc_idx, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_list):
"""This code is mostly a duplicate of the equivalent function from Google BERT's repo.
However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
(rather than each document) has an equal chance of being sampled as a false example for the NextSentence task."""
document = doc_database[doc_idx]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if random() < short_seq_prob:
target_seq_length = randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
if len(current_chunk) == 1 or random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
# random_document = get_random_doc(all_documents, document, doc_weights)
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
random_start = randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
# The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP]
# They are 1 for the B tokens and the final [SEP]
segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]
tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
instance = {
"tokens": tokens,
"segment_ids": segment_ids,
"is_random_next": is_random_next,
"masked_lm_positions": masked_lm_positions,
"masked_lm_labels": masked_lm_labels}
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
def main():
parser = ArgumentParser()
parser.add_argument('--corpus_path', type=Path, required=True)
parser.add_argument("--save_dir", type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--epochs_to_generate", type=int, default=3,
help="Number of epochs of data to pregenerate")
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--short_seq_prob", type=float, default=0.1,
help="Probability of making a short sentence as a training example")
parser.add_argument("--masked_lm_prob", type=float, default=0.15,
help="Probability of masking each token for the LM task")
parser.add_argument("--max_predictions_per_seq", type=int, default=20,
help="Maximum number of tokens to mask in each sequence")
args = parser.parse_args()
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
vocab_list = list(tokenizer.vocab.keys())
with args.corpus_path.open() as f:
docs = []
doc = []
for line in tqdm(f, desc="Loading Dataset"):
line = line.strip()
if line == "":
docs.append(doc)
doc = []
else:
tokens = tokenizer.tokenize(line)
# TODO If the sentence is longer than max_len, do we split it in the middle? That's probably a bad idea
doc.append(tokens)
args.save_dir.mkdir(exist_ok=True)
docs = DocumentDatabase(docs)
# When choosing a random sentence, we should sample docs proportionally to the number of sentences they contain
# Google BERT doesn't do this, and as a result oversamples shorter docs
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
epoch_instances = []
for doc_idx in trange(len(docs), desc="Document"):
doc_instances = create_instances_from_document(
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
vocab_list=vocab_list)
doc_instances = [json.dumps(instance) for instance in doc_instances]
epoch_instances.extend(doc_instances)
shuffle(epoch_instances)
epoch_file = args.save_dir / f"epoch_{epoch}.json"
metrics_file = args.save_dir / f"epoch_{epoch}_metrics.json"
with epoch_file.open('w') as out_file:
for instance in epoch_instances:
out_file.write(instance + '\n')
with metrics_file.open('w') as metrics_file:
metrics = {
"num_training_examples": len(epoch_instances),
"max_seq_len": args.max_seq_len
}
metrics_file.write(json.dumps(metrics))
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment