Commit 71553480 authored by LysandreJik

BERT + RoBERTa masking tokens handling + GPU device update.

parent 339e556f
...@@ -65,11 +65,15 @@ def set_seed(args):
 def mask_tokens(inputs, tokenizer, args):
     labels = inputs.clone()
     masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
-    labels[~masked_indices] = -1 # We only compute loss on masked tokens
+    labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens
     indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
-    inputs[indices_replaced] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK]
-    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
-    random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long, device=args.device)
+    if args.model_name == "bert":
+        inputs[indices_replaced.bool()] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK]
+    elif args.model_name == "roberta":
+        inputs[indices_replaced.bool()] = tokenizer.encoder["<mask>"] # 80% of the time, replace masked input tokens with <mask>
+    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
+    random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
     inputs[indices_random] = random_words[
         indices_random] # 10% of the time, replace masked input tokens with random word
     return inputs, labels
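For context, `mask_tokens` implements the standard BERT/RoBERTa dynamic masking recipe: a fraction `mlm_probability` of positions is selected, and of those 80% are replaced by the model's mask token, 10% by a random token, and 10% are left unchanged; the added `.bool()` casts avoid deprecated `uint8` mask indexing in newer PyTorch versions. A minimal, self-contained sketch of the same scheme (the token id, vocabulary size, and function name are illustrative assumptions, not values from this repo):

```python
import torch

def mlm_mask(inputs, mask_token_id=103, vocab_size=30522, mlm_probability=0.15):
    """Toy re-implementation of the 80/10/10 masking scheme, using bool masks."""
    labels = inputs.clone()
    masked = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked] = -1  # loss is only computed on masked positions

    # 80% of masked positions -> the mask token ([MASK] for BERT, <mask> for RoBERTa)
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
    inputs[replaced] = mask_token_id

    # 10% of masked positions -> a random token; the remaining 10% stay as-is
    randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
    inputs[randomized] = torch.randint(vocab_size, labels.shape, dtype=torch.long)[randomized]
    return inputs, labels

inputs, labels = mlm_mask(torch.randint(1000, (2, 8), dtype=torch.long))
```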
...@@ -132,14 +136,15 @@ def train(args, train_dataset, model, tokenizer):
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
-            batch.to(args.device)
-            model.train()
             inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
+            inputs = inputs.to(args.device)
+            labels = labels.to(args.device)
+            model.train()
             outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
             loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu parallel training
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
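Two details in the tail of this hunk are easy to miss: with `torch.nn.DataParallel` the loss comes back as one value per GPU, so it has to be reduced with `.mean()`, and dividing by `gradient_accumulation_steps` keeps the accumulated gradient equivalent to a single averaged batch. A hedged, illustrative sketch (the numbers are made up):

```python
import torch

loss = torch.tensor([0.9, 1.1])  # pretend per-GPU losses returned by DataParallel
n_gpu, gradient_accumulation_steps = 2, 4

if n_gpu > 1:
    loss = loss.mean()  # -> tensor(1.0), one scalar loss across GPUs
if gradient_accumulation_steps > 1:
    loss = loss / gradient_accumulation_steps  # -> tensor(0.2500)
print(loss)
```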
...@@ -214,7 +219,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     nb_eval_steps = 0
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
         model.eval()
-        batch.to(args.device)
+        batch = batch.to(args.device)
         with torch.no_grad():
             outputs = model(batch)
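The fix here (and the reordering in `train`) relies on the fact that `Tensor.to()` is not in-place: it returns a copy on the target device, so calling it without re-assigning the result silently leaves the batch on the CPU. A quick illustrative check, not code from the repo:

```python
import torch

batch = torch.zeros(2, 3)      # created on the CPU
if torch.cuda.is_available():
    batch.to("cuda")           # returns a new tensor; `batch` itself is unchanged
    print(batch.device)        # still cpu
    batch = batch.to("cuda")   # the re-assignment is what actually moves it
    print(batch.device)        # cuda:0
```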
...@@ -285,9 +290,9 @@ def main():
     parser.add_argument("--do_lower_case", action='store_true',
                         help="Set this flag if you are using an uncased model.")
-    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                         help="Batch size per GPU/CPU for training.")
-    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+    parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
                         help="Batch size per GPU/CPU for evaluation.")
     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
...@@ -299,7 +304,7 @@ def main():
                         help="Epsilon for Adam optimizer.")
     parser.add_argument("--max_grad_norm", default=1.0, type=float,
                         help="Max gradient norm.")
-    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--max_steps", default=-1, type=int,
                         help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
...
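As the help text says, a positive `--max_steps` overrides `--num_train_epochs` when sizing the training run and its learning-rate schedule. A hedged sketch of how that override is commonly applied in scripts of this family; the exact code in this repo may differ:

```python
def total_optimization_steps(num_batches, gradient_accumulation_steps=1,
                             num_train_epochs=1.0, max_steps=-1):
    """Number of optimizer steps the schedule should be built for."""
    if max_steps > 0:
        return max_steps  # --max_steps wins outright
    steps_per_epoch = num_batches // gradient_accumulation_steps
    return int(steps_per_epoch * num_train_epochs)

print(total_optimization_steps(num_batches=1000))                 # 1000
print(total_optimization_steps(num_batches=1000, max_steps=250))  # 250
```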
...@@ -6,8 +6,7 @@ import torch.nn.functional as F
 class WikiTextDataset(Dataset):
-    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512, device='cpu'):
-        self.device = device
+    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512):
         self.max_context_length = max_context_length
         self.examples = []
...@@ -32,7 +31,7 @@ class WikiTextDataset(Dataset):
         return len(self.examples)

     def __getitem__(self, item):
-        return torch.tensor(self.examples[item], device=self.device)
+        return torch.tensor(self.examples[item])

     @staticmethod
     def collate(values):
...
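With this change the dataset always yields CPU tensors and the training/evaluation loops own the device transfer, which is the pattern that plays well with `DataLoader` workers and `pin_memory`. A hedged, self-contained sketch of that pattern (the toy dataset below is an assumption, not the repo's `WikiTextDataset`):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples  # plain Python lists of token ids

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item])  # always a CPU tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
loader = DataLoader(ToyDataset([[1, 2, 3, 4]] * 8), batch_size=4,
                    pin_memory=(device == "cuda"))
for batch in loader:
    batch = batch.to(device)  # device transfer happens in the loop, not in __getitem__
```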