Unverified Commit 2fb8ddee authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #392 from Rocketknight1/master

Add full language model fine-tuning
parents f3e54048 34561e61
......@@ -1051,18 +1051,7 @@ You can download an [exemplary training corpus](https://ext-bert-sample.obs.eu-d
Training one epoch on this corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with `train_batch_size=200` and `max_seq_length=128`:
```shell
python run_lm_finetuning.py \
--bert_model bert-base-uncased \
--do_lower_case \
--do_train \
--train_file ../samples/sample_text.txt \
--output_dir models \
--num_train_epochs 5.0 \
--learning_rate 3e-5 \
--train_batch_size 32 \
--max_seq_length 128 \
```
Thank to the work of @Rocketknight1 and @tholor there are now **several scripts** that can be used to fine-tune BERT using the pretraining objective (combination of masked-language modeling and next sentence prediction loss). These scripts are detailed in the [`README`](./examples/lm_finetuning/README.md) of the [`examples/lm_finetuning/`](./examples/lm_finetuning/) folder.
### OpenAI GPT, Transformer-XL and GPT-2: running the examples
......
# BERT Model Finetuning using Masked Language Modeling objective
## Introduction
The three example scripts in this folder can be used to **fine-tune** a pre-trained BERT model using the pretraining objective (combination of masked language modeling and next sentence prediction loss). In general, pretrained models like BERT are first trained with a pretraining objective (masked language modeling and next sentence prediction for BERT) on a large and general natural language corpus. A classifier head is then added on top of the pre-trained architecture and the model is quickly fine-tuned on a target task, while still (hopefully) retaining its general language understanding. This greatly reduces overfitting and yields state-of-the-art results, especially when training data for the target task are limited.
The [ULMFiT paper](https://arxiv.org/abs/1801.06146) took a slightly different approach, however, and added an intermediate step in which the model is fine-tuned on text **from the same domain as the target task and using the pretraining objective** before the final stage in which the classifier head is added and the model is trained on the target task itself. This paper reported significantly improved results from this step, and found that they could get high-quality classifications even with only tiny numbers (<1000) of labelled training examples, as long as they had a lot of unlabelled data from the target domain.
The BERT model has more capacity than the LSTM models used in the ULMFiT work, but the [BERT paper](https://arxiv.org/abs/1810.04805) did not test finetuning using the pretraining objective and at the present stage there aren't many examples of this approach being used for Transformer-based language models. As such, it's hard to predict what effect this step will have on final model performance, but it's reasonable to conjecture that this approach can improve the final classification performance, especially when a large unlabelled corpus from the target domain is available, labelled data is limited, or the target domain is very unusual and different from 'normal' English text. If you are aware of any literature on this subject, please feel free to add it in here, or open an issue and tag me (@Rocketknight1) and I'll include it.
## Input format
The scripts in this folder expect a single file as input, consisting of untokenized text, with one **sentence** per line, and one blank line between documents. The reason for the sentence splitting is that part of BERT's training involves a _next sentence_ objective in which the model must predict whether two sequences of text are contiguous text from the same document or not, and to avoid making the task _too easy_, the split point between the sequences is always at the end of a sentence. The linebreaks in the file are therefore necessary to mark the points where the text can be split.
## Usage
There are two ways to fine-tune a language model using these scripts. The first _quick_ approach is to use [`simple_lm_finetuning.py`](./simple_lm_finetuning.py). This script does everything in a single script, but generates training instances that consist of just two sentences. This is quite different from the BERT paper, where (confusingly) the NextSentence task concatenated sentences together from each document to form two long multi-sentences, which the paper just referred to as _sentences_. The difference between this simple approach and the original paper approach can have a significant effect for long sequences since two sentences will be much shorter than the max sequence length. In this case, most of each training example will just consist of blank padding characters, which wastes a lot of computation and results in a model that isn't really training on long sequences.
As such, the preferred approach (assuming you have documents containing multiple contiguous sentences from your target domain) is to use [`pregenerate_training_data.py`](./pregenerate_training_data.py) to pre-process your data into training examples following the methodology used for LM training in the original BERT paper and repository. Since there is a significant random component to training data generation for BERT, this script includes an option to generate multiple _epochs_ of pre-processed data, to avoid training on the same random splits each epoch. Generating an epoch of data for each training epoch should result a better final model, and so we recommend doing so.
You can then train on the pregenerated data using [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py), and pointing it to the folder created by [`pregenerate_training_data.py`](./pregenerate_training_data.py). Note that you should use the same `bert_model` and case options for both! Also note that `max_seq_len` does not need to be specified for the [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py) script, as it is inferred from the training examples.
There are various options that can be tweaked, but they are mostly set to the values from the BERT paper/repository and default values should make sense. The most relevant ones are:
- `--max_seq_len`: Controls the length of training examples (in wordpiece tokens) seen by the model. Defaults to 128 but can be set as high as 512. Higher values may yield stronger language models at the cost of slower and more memory-intensive training.
- `--fp16`: Enables fast half-precision training on recent GPUs.
In addition, if memory usage is an issue, especially when training on a single GPU, reducing `--train_batch_size` from the default 32 to a lower number (4-16) can be helpful, or leaving `--train_batch_size` at the default and increasing `--gradient_accumulation_steps` to 2-8. Changing `--gradient_accumulation_steps` may be preferable as alterations to the batch size may require corresponding changes in the learning rate to compensate. There is also a `--reduce_memory` option for both the `pregenerate_training_data.py` and `finetune_on_pregenerated.py` scripts that spills data to disc in shelf objects or numpy memmaps rather than retaining it in memory, which significantly reduces memory usage with little performance impact.
## Examples
### Simple fine-tuning
```
python3 simple_lm_finetuning.py
--train_corpus my_corpus.txt
--bert_model bert-base-uncased
--do_lower_case
--output_dir finetuned_lm/
```
### Pregenerating training data
```
python3 pregenerate_training_data.py
--train_corpus my_corpus.txt
--bert_model bert-base-uncased
--do_lower_case
--output_dir training/
--epochs_to_generate 3
--max_seq_len 256
```
### Training on pregenerated data
```
python3 finetune_on_pregenerated.py
--pregenerated_data training/
--bert_model bert-base-uncased
--do_lower_case
--output_dir finetuned_lm/
--epochs 3
```
\ No newline at end of file
from argparse import ArgumentParser
from pathlib import Path
import torch
import logging
import json
import random
import numpy as np
from collections import namedtuple
from tempfile import TemporaryDirectory
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
log_format = '%(asctime)-10s: %(message)s'
logging.basicConfig(level=logging.INFO, format=log_format)
def convert_example_to_features(example, tokenizer, max_seq_length):
tokens = example["tokens"]
segment_ids = example["segment_ids"]
is_random_next = example["is_random_next"]
masked_lm_positions = example["masked_lm_positions"]
masked_lm_labels = example["masked_lm_labels"]
assert len(tokens) == len(segment_ids) <= max_seq_length # The preprocessed data should be already truncated
input_ids = tokenizer.convert_tokens_to_ids(tokens)
masked_label_ids = tokenizer.convert_tokens_to_ids(masked_lm_labels)
input_array = np.zeros(max_seq_length, dtype=np.int)
input_array[:len(input_ids)] = input_ids
mask_array = np.zeros(max_seq_length, dtype=np.bool)
mask_array[:len(input_ids)] = 1
segment_array = np.zeros(max_seq_length, dtype=np.bool)
segment_array[:len(segment_ids)] = segment_ids
lm_label_array = np.full(max_seq_length, dtype=np.int, fill_value=-1)
lm_label_array[masked_lm_positions] = masked_label_ids
features = InputFeatures(input_ids=input_array,
input_mask=mask_array,
segment_ids=segment_array,
lm_label_ids=lm_label_array,
is_next=is_random_next)
return features
class PregeneratedDataset(Dataset):
def __init__(self, training_path, epoch, tokenizer, num_data_epochs, reduce_memory=False):
self.vocab = tokenizer.vocab
self.tokenizer = tokenizer
self.epoch = epoch
self.data_epoch = epoch % num_data_epochs
data_file = training_path / f"epoch_{self.data_epoch}.json"
metrics_file = training_path / f"epoch_{self.data_epoch}_metrics.json"
assert data_file.is_file() and metrics_file.is_file()
metrics = json.loads(metrics_file.read_text())
num_samples = metrics['num_training_examples']
seq_len = metrics['max_seq_len']
self.temp_dir = None
self.working_dir = None
if reduce_memory:
self.temp_dir = TemporaryDirectory()
self.working_dir = Path(self.temp_dir.name)
input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
segment_ids = np.memmap(filename=self.working_dir/'input_masks.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
lm_label_ids[:] = -1
is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
shape=(num_samples,), mode='w+', dtype=np.bool)
else:
input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
input_masks = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
logging.info(f"Loading training examples for epoch {epoch}")
with data_file.open() as f:
for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
line = line.strip()
example = json.loads(line)
features = convert_example_to_features(example, tokenizer, seq_len)
input_ids[i] = features.input_ids
segment_ids[i] = features.segment_ids
input_masks[i] = features.input_mask
lm_label_ids[i] = features.lm_label_ids
is_nexts[i] = features.is_next
assert i == num_samples - 1 # Assert that the sample count metric was true
logging.info("Loading complete!")
self.num_samples = num_samples
self.seq_len = seq_len
self.input_ids = input_ids
self.input_masks = input_masks
self.segment_ids = segment_ids
self.lm_label_ids = lm_label_ids
self.is_nexts = is_nexts
def __len__(self):
return self.num_samples
def __getitem__(self, item):
return (torch.tensor(self.input_ids[item].astype(np.int64)),
torch.tensor(self.input_masks[item].astype(np.int64)),
torch.tensor(self.segment_ids[item].astype(np.int64)),
torch.tensor(self.lm_label_ids[item].astype(np.int64)),
torch.tensor(self.is_nexts[item].astype(np.int64)))
def main():
parser = ArgumentParser()
parser.add_argument('--pregenerated_data', type=Path, required=True)
parser.add_argument('--output_dir', type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--reduce_memory", action="store_true",
help="Store training data as on-disc memmaps to massively reduce memory usage")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--learning_rate",
default=3e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
args = parser.parse_args()
assert args.pregenerated_data.is_dir(), \
"--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
samples_per_epoch = []
for i in range(args.epochs):
epoch_file = args.pregenerated_data / f"epoch_{i}.json"
metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
if epoch_file.is_file() and metrics_file.is_file():
metrics = json.loads(metrics_file.read_text())
samples_per_epoch.append(metrics['num_training_examples'])
else:
if i == 0:
exit("No training data was found!")
print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
print("This script will loop over the available data, but training diversity may be negatively impacted.")
num_data_epochs = i
break
else:
num_data_epochs = args.epochs
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logging.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
logging.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
args.output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
total_train_examples = 0
for i in range(args.epochs):
# The modulo takes into account the fact that we may loop over limited epochs of data
total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]
num_train_optimization_steps = int(
total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model
model = BertForPreTraining.from_pretrained(args.bert_model)
if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
global_step = 0
logging.info("***** Running training *****")
logging.info(f" Num examples = {total_train_examples}")
logging.info(" Batch size = %d", args.train_batch_size)
logging.info(" Num steps = %d", num_train_optimization_steps)
model.train()
for epoch in range(args.epochs):
epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
num_data_epochs=num_data_epochs)
if args.local_rank == -1:
train_sampler = RandomSampler(epoch_dataset)
else:
train_sampler = DistributedSampler(epoch_dataset)
train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
for step, batch in enumerate(train_dataloader):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
pbar.update(1)
mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps,
args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
# Save a trained model
logging.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = args.output_dir / "pytorch_model.bin"
torch.save(model_to_save.state_dict(), str(output_model_file))
if __name__ == '__main__':
main()
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve
from random import random, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json
class DocumentDatabase:
def __init__(self, reduce_memory=False):
if reduce_memory:
self.temp_dir = TemporaryDirectory()
self.working_dir = Path(self.temp_dir.name)
self.document_shelf_filepath = self.working_dir / 'shelf.db'
self.document_shelf = shelve.open(str(self.document_shelf_filepath),
flag='n', protocol=-1)
self.documents = None
else:
self.documents = []
self.document_shelf = None
self.document_shelf_filepath = None
self.temp_dir = None
self.doc_lengths = []
self.doc_cumsum = None
self.cumsum_max = None
self.reduce_memory = reduce_memory
def add_document(self, document):
if self.reduce_memory:
current_idx = len(self.doc_lengths)
self.document_shelf[str(current_idx)] = document
else:
self.documents.append(document)
self.doc_lengths.append(len(document))
def _precalculate_doc_weights(self):
self.doc_cumsum = np.cumsum(self.doc_lengths)
self.cumsum_max = self.doc_cumsum[-1]
def sample_doc(self, current_idx, sentence_weighted=True):
# Uses the current iteration counter to ensure we don't sample the same doc twice
if sentence_weighted:
# With sentence weighting, we sample docs proportionally to their sentence length
if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths):
self._precalculate_doc_weights()
rand_start = self.doc_cumsum[current_idx]
rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx]
sentence_index = randint(rand_start, rand_end) % self.cumsum_max
sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right')
else:
# If we don't use sentence weighting, then every doc has an equal chance to be chosen
sampled_doc_index = current_idx + randint(1, len(self.doc_lengths)-1)
assert sampled_doc_index != current_idx
if self.reduce_memory:
return self.document_shelf[str(sampled_doc_index)]
else:
return self.documents[sampled_doc_index]
def __len__(self):
return len(self.doc_lengths)
def __getitem__(self, item):
if self.reduce_memory:
return self.document_shelf[str(item)]
else:
return self.documents[item]
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, traceback):
if self.document_shelf is not None:
self.document_shelf.close()
if self.temp_dir is not None:
self.temp_dir.cleanup()
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
"""Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
"""Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
with several refactors to clean it up and remove a lot of unnecessary variables."""
cand_indices = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indices.append(i)
num_to_mask = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
shuffle(cand_indices)
mask_indices = sorted(sample(cand_indices, num_to_mask))
masked_token_labels = []
for index in mask_indices:
# 80% of the time, replace with [MASK]
if random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = choice(vocab_list)
masked_token_labels.append(tokens[index])
# Once we've saved the true label for that token, we can overwrite it with the masked version
tokens[index] = masked_token
return tokens, mask_indices, masked_token_labels
def create_instances_from_document(
doc_database, doc_idx, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_list):
"""This code is mostly a duplicate of the equivalent function from Google BERT's repo.
However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
(rather than each document) has an equal chance of being sampled as a false example for the NextSentence task."""
document = doc_database[doc_idx]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if random() < short_seq_prob:
target_seq_length = randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
if len(current_chunk) == 1 or random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# Sample a random document, with longer docs being sampled more frequently
random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True)
random_start = randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
# The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP]
# They are 1 for the B tokens and the final [SEP]
segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]
tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
instance = {
"tokens": tokens,
"segment_ids": segment_ids,
"is_random_next": is_random_next,
"masked_lm_positions": masked_lm_positions,
"masked_lm_labels": masked_lm_labels}
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
def main():
parser = ArgumentParser()
parser.add_argument('--train_corpus', type=Path, required=True)
parser.add_argument("--output_dir", type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True,
choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
"bert-base-multilingual", "bert-base-chinese"])
parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--reduce_memory", action="store_true",
help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
parser.add_argument("--epochs_to_generate", type=int, default=3,
help="Number of epochs of data to pregenerate")
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument("--short_seq_prob", type=float, default=0.1,
help="Probability of making a short sentence as a training example")
parser.add_argument("--masked_lm_prob", type=float, default=0.15,
help="Probability of masking each token for the LM task")
parser.add_argument("--max_predictions_per_seq", type=int, default=20,
help="Maximum number of tokens to mask in each sequence")
args = parser.parse_args()
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
vocab_list = list(tokenizer.vocab.keys())
with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
with args.train_corpus.open() as f:
doc = []
for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
line = line.strip()
if line == "":
docs.add_document(doc)
doc = []
else:
tokens = tokenizer.tokenize(line)
doc.append(tokens)
args.output_dir.mkdir(exist_ok=True)
for epoch in trange(args.epochs_to_generate, desc="Epoch"):
epoch_filename = args.output_dir / f"epoch_{epoch}.json"
num_instances = 0
with epoch_filename.open('w') as epoch_file:
for doc_idx in trange(len(docs), desc="Document"):
doc_instances = create_instances_from_document(
docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
vocab_list=vocab_list)
doc_instances = [json.dumps(instance) for instance in doc_instances]
for instance in doc_instances:
epoch_file.write(instance + '\n')
num_instances += 1
metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json"
with metrics_file.open('w') as metrics_file:
metrics = {
"num_training_examples": num_instances,
"max_seq_len": args.max_seq_len
}
metrics_file.write(json.dumps(metrics))
if __name__ == '__main__':
main()
......@@ -33,9 +33,6 @@ from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from torch.utils.data import Dataset
import random
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
......@@ -404,7 +401,7 @@ def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--train_file",
parser.add_argument("--train_corpus",
default=None,
type=str,
required=True,
......@@ -514,8 +511,8 @@ def main():
#train_examples = None
num_train_optimization_steps = None
if args.do_train:
print("Loading Train Dataset", args.train_file)
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
print("Loading Train Dataset", args.train_corpus)
train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
corpus_lines=None, on_memory=args.on_memory)
num_train_optimization_steps = int(
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment