Commit a1d04b79 authored by Jared Casper

Updating public repo with latest changes.

parent 93ab4bea
__pycache__
......@@ -114,7 +114,8 @@ def add_training_args(parser):
help='report interval')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after this many new iterations.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory')
group.add_argument('--seed', type=int, default=1234,
help='random seed')
# Batch producer arguments
......@@ -123,6 +124,8 @@ def add_training_args(parser):
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention mask after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens')
# Learning rate.
group.add_argument('--lr-decay-iters', type=int, default=None,
......@@ -133,9 +136,25 @@ def add_training_args(parser):
help='learning rate decay function')
group.add_argument('--lr', type=float, default=1.0e-4,
help='initial learning rate')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--warmup', type=float, default=0.01,
help='percentage of data to warmup on (.01 = 1% of all '
'training iters). Default 0.01')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style) from input '
'arguments and ignore values from checkpoints. Note '
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style) '
'from the checkpoint and ignore the input arguments.')
# model checkpointing
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
......@@ -163,8 +182,17 @@ def add_training_args(parser):
group.add_argument('--distributed-backend', default='nccl',
help='which backend to use for distributed '
'training. One of [gloo, nccl]')
group.add_argument('--DDP-impl', default='local',
help='which DistributedDataParallel implementation '
'to use. One of [local, torch]')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher')
# autoresume
group.add_argument('--adlr-autoresume', action='store_true',
help='enable autoresume on adlr cluster.')
group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
help='interval over which to check for autoresume '
'termination signal')
return parser
......@@ -193,6 +221,8 @@ def add_evaluation_args(parser):
help='sliding window for overlapping eval ')
group.add_argument('--cloze-eval', action='store_true',
help='Evaluation dataset from `--valid-data` is a cloze task')
group.add_argument('--strict-lambada', action='store_true',
help='use the more difficult formulation of lambada')
group.add_argument('--eval-hf', action='store_true',
help='perform evaluation with huggingface openai model. '
'use `--load` to specify weights path to be loaded')
......@@ -207,9 +237,23 @@ def add_text_generate_args(parser):
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=256)
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default="",
help='get input from file instead of interactive mode, '
'each line is an input' )
group.add_argument("--sample-output-file", type=str, default="",
help='output file generated from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='number of samples to generate unconditionally, '
'defaults to 0, which means interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='during generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
......
......@@ -148,7 +148,8 @@ def make_loaders(args):
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir,
'max_preds_per_seq': args.max_preds_per_seq,
'presplit_sentences': args.presplit_sentences}
'presplit_sentences': args.presplit_sentences,
'parallel_group': mpu.get_data_parallel_group()}
eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
......
......@@ -16,6 +16,8 @@
import os
import math
import torch
from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
......@@ -61,7 +63,8 @@ def supported_corpus(corpus_name):
def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None, **kwargs):
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
parallel_group=None, **kwargs):
"""function to create datasets+tokenizers for common options"""
if isinstance(process_fn, str):
process_fn = eval(process_fn)
......@@ -76,11 +79,19 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
named_corpora = True
name = path_
path_ = corpora.NAMED_CORPORA[path_].PATH
if not exists_lazy(path_, data_type='data'):
if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
# create cached version of dataset for lazy loading if it doesn't exist
text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
make_lazy(path_, text.X, data_type='data')
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_group)
assert counts[0].item() == torch.distributed.get_world_size(
group=parallel_group)
text = lazy_array_loader(path_, data_type='data', map_fn=process_fn)
else:
# get dataset
......@@ -107,15 +118,17 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
# Split dataset into train/val/test (and wrap bert dataset)
if should_split(split):
ds = split_ds(ds, split)
if ds_type.lower() == 'bert':
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2':
ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
else:
if ds_type.lower() == 'bert':
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
dstype = bert_sentencepair_dataset
ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
elif ds_type.lower() == 'gpt2':
ds = GPT2Dataset(ds, max_seq_len=seq_length)
return ds, tokenizer
......@@ -461,6 +461,7 @@ class GPT2Dataset(data.Dataset):
weighted=True,
sample_across_doc=True,
random_across_doc_sampling=True,
bias_for_single_doc=False,
sentence_start=False, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
......@@ -473,6 +474,7 @@ class GPT2Dataset(data.Dataset):
self.weighted = weighted
self.sample_across_doc = sample_across_doc
self.random_across_doc_sampling = random_across_doc_sampling
self.bias_for_single_doc = bias_for_single_doc
self.sentence_start = sentence_start
self.init_weighting()
......@@ -510,7 +512,10 @@ class GPT2Dataset(data.Dataset):
# truncate or pad tokens
num_tokens = len(tokens)
if self.bias_for_single_doc:
tokens_to_strip = num_tokens - self.max_seq_len - 1
else:
tokens_to_strip = num_tokens - 1
if tokens_to_strip > 0:
strip_left_tokens = rng.randint(tokens_to_strip + 1)
tokens = tokens[strip_left_tokens:]
......@@ -576,7 +581,7 @@ class bert_sentencepair_dataset(data.Dataset):
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True,**kwargs):
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
......@@ -758,7 +763,8 @@ class bert_sentencepair_dataset(data.Dataset):
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
max_num_tokens = max_seq_len - 3
max_num_tokens = self.calc_seq_len(max_seq_len)
# max_num_tokens = max_seq_len - 3
while True:
len_a = len(tokens_a)
len_b = len(tokens_b)
......@@ -782,6 +788,9 @@ class bert_sentencepair_dataset(data.Dataset):
trunc_types.pop()
return (tokens_a, token_types_a), (tokens_b, token_types_b)
def calc_seq_len(self, max_seq_len):
return max_seq_len - 3
def mask_token(self, idx, tokens, types, vocab_words, rng):
"""
helper function to mask `idx` token from `tokens` according to
......@@ -807,6 +816,11 @@ class bert_sentencepair_dataset(data.Dataset):
seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask
def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
return tokens, token_types
def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
"""
Mask sequence pair for BERT training according to:
......@@ -814,8 +828,7 @@ class bert_sentencepair_dataset(data.Dataset):
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
tokens, token_types = self.concat_tokens(tokens_a, token_types_a, tokens_b, token_types_b)
len_a = len(tokens_a)
len_b = len(tokens_b)
......
......@@ -111,7 +111,7 @@ class lazy_array_loader(object):
lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb')
self._file = open(datapath, 'rb', buffering=0)
self.file = self._file
#memory map file if necessary
self.mem_map = mem_map
......
......@@ -795,7 +795,6 @@ class BertWordPieceTokenizer(Tokenizer):
Tokens = Tokens.tokenization
return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
......@@ -887,4 +886,3 @@ class GPT2BPETokenizer(Tokenizer):
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
......@@ -210,10 +210,13 @@ def forward_step(data_iterator, model, args, timers):
lm_loss = torch.sum(
losses.view(-1) * loss_mask.float())
else:
outputs = torch.argmax(output, -1).contiguous().view(-1)
acc = (outputs == lm_labels.contiguous().view(-1)).float()
loss_mask = loss_mask.contiguous().view(-1).float()
lm_loss = torch.sum(acc * loss_mask)
outputs = torch.argmax(output, -1)
correct = (outputs == lm_labels).float()
# ignore positions masked out of the loss, then count a sample as correct
# only if every remaining target token in the sequence matches
correct[(1-loss_mask).bool()] = 1
correct = correct.prod(-1)
lm_loss = correct.sum()
# loss_mask = loss_mask.contiguous().view(-1).float()
# lm_loss = torch.sum(acc * loss_mask)
return lm_loss
......@@ -345,7 +348,7 @@ def set_random_seed(seed):
class LM_Eval_Dataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None):
def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None, **kwargs):
self.tokens = tokens
self.seq_len = seq_len
self.pad_idx = pad_idx
......@@ -379,15 +382,30 @@ class LM_Eval_Dataset(torch.utils.data.Dataset):
return {'text': np.array(tokens), 'pad_mask': pad_mask}
class Lambada_Eval_Dataset(torch.utils.data.Dataset):
def __init__(self, path, tokenizer, seq_len):
def __init__(self, path, tokenizer, seq_len, strict=False, **kwargs):
self.seq_len = seq_len
self.pad_idx = tokenizer.get_command('pad').Id
self.tokenizer = tokenizer
self.strict = strict
self.tokens = []
self.labels = []
with open(path, 'r') as f:
for line in f.readlines():
text = json.loads(line)['text']
self.tokens.append(tokenizer.EncodeAsIds(text).tokenization)
tokens, labels = self.get_tokens(text)
self.tokens.append(tokens)
self.labels.append(labels)
def get_tokens(self, text):
if not self.strict:
# non-strict: the target is just the final BPE token of the passage
tokens = self.tokenizer.EncodeAsIds(text).tokenization
return tokens[:-1], [tokens[-1]]
# strict: the target is the last whitespace-delimited word, tokenized separately
# from the context (it may span several BPE tokens)
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = self.tokenizer.EncodeAsIds(text[:start_idx].strip()).tokenization
last_token = self.tokenizer.EncodeAsIds(' '+last_token).tokenization
return beginning_tokens, last_token
def __len__(self):
return len(self.tokens)
......@@ -397,7 +415,10 @@ class Lambada_Eval_Dataset(torch.utils.data.Dataset):
tokens = self.tokens[idx]
num_tokens = len(tokens)
pad_mask = [0]*num_tokens
pad_mask[-1] = 1
labels = self.labels[idx]
pad_mask += [1]*len(labels)
tokens = tokens+labels
num_tokens = len(tokens)
if num_tokens < self.seq_len+1:
num_pad = (self.seq_len+1-num_tokens)
pad_mask += [0]*(num_pad)
......@@ -442,7 +463,7 @@ def get_eval_data(args):
val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
args.overlapping_eval)
else:
val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len)
val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len, args.strict_lambada)
num_tokenized_tokens = 0
num_original_tokens = 0
val_dataloader = torch.utils.data.DataLoader(
......@@ -450,7 +471,9 @@ def get_eval_data(args):
before = tokenizer.num_tokens
after = before
while after % mpu.get_model_parallel_world_size() != 0:
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'.
format(before, after - before, after))
......
......@@ -17,6 +17,8 @@
import os
import random
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F
......@@ -83,9 +85,10 @@ def setup_model(args):
return model
def get_batch(context_tokens, device, args):
def get_batch(context_tokens, args):
tokens = context_tokens
tokens = tokens.view(args.batch_size, -1).contiguous()
device = args.device
tokens = tokens.to(device)
# Get the masks and position ids.
......@@ -108,8 +111,8 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
if top_p > 0.0:
#convert to 1D
logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
# logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
......@@ -117,15 +120,32 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
#going back to 2D
logits=logits.view(1, -1).contiguous()
# logits=logits.view(1, -1).contiguous()
return logits
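For reference, the per-row loop above performs batched nucleus (top-p) filtering; a hypothetical vectorized restatement, not part of this commit, would scatter the sorted removal mask back to vocabulary order (function name and 2D [batch, vocab] logits shape are assumptions here):

import torch
import torch.nn.functional as F

def top_p_filter(logits, top_p, filter_value=-float('Inf')):
    # sort each row, find where cumulative probability exceeds top_p,
    # keep the first token past the threshold, and scatter the mask back
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > top_p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = torch.zeros_like(sorted_mask).scatter_(-1, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, filter_value)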
def generate_samples_input_from_file(model, tokenizer, args):
if args.sample_input_file == "":
if mpu.get_model_parallel_rank() == 0:
print("args.sample_input_file CAN NOT BE empty!\n")
return
def generate_samples(model, tokenizer, args, device):
if mpu.get_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file == "":
print("Argument: sample-output-file can't be empty, setting it to\n")
print("\t args.sample_input_file.out")
args.sample_output_file = args.sample_input_file+".out"
fname_out = open(args.sample_output_file, "w+")
context_count=0
model.eval()
......@@ -135,6 +155,74 @@ def generate_samples(model, tokenizer, args, device):
terminate_runs=0
if mpu.get_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
raw_text = "stop"
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
context_length = len(context_tokens)
if context_length >= args.seq_length//2:
print("\nContext length", context_length, \
"\nPlease give a smaller context (at most half of the sequence length)!")
continue
else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:")
fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out.write("\n")
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
def generate_samples_interactive(model, tokenizer, args):
print_frequency = 24
context_count=0
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs=0
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text:
print('Prompt should not be empty!')
......@@ -161,61 +249,180 @@ def generate_samples(model, tokenizer, args, device):
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
if mpu.get_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
def generate_samples_unconditional(model, tokenizer, args):
num_samples = args.num_samples
context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)]
samples = []
# with open(args.genfile, 'w') as f:
ctr = 0
while True:
start_time = time.time()
for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args):
pass
# token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
if ctr%args.log_interval == 0:
print('Avg s/batch:', (time.time()-start_time)/min(args.log_interval, ctr+1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length-1]
text = tokenizer.DecodeIds(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length-1, 'finished': is_finished}
yield datum
ctr += 1
if ctr >= num_samples:
break
if ctr >= num_samples:
break
def write_and_generate_samples_unconditional(model, tokenizer, args):
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model, tokenizer, args):
f.write(json.dumps(datum)+'\n')
def pad_batch(batch, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
context_tokens.extend([pad_id] * (args.seq_length - context_length))
tokens.extend([pad_id]*(args.seq_length-context_length))
context_lengths.append(context_length)
return batch, context_lengths
def get_token_stream(model, context_tokens, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
# context_length = len(context_tokens)
# if context_length < args.seq_length:
# context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor([context_length])
context_length_tensor = torch.cuda.LongTensor(context_lengths)
# context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
context_length = context_length_tensor[0].item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, device, args)
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, args)
start_time = time.time()
counter = 0
org_context_length = context_length
layer_past = None
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
for tokens, lengths in batch_token_iterator:
context_length += 1
yield tokens[:, :context_length], lengths
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1-boolean)*val1 + boolean*val2
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.get_command('eos').Id
counter = 0
org_context_length = context_length
while counter < (org_context_length + args.out_seq_length):
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda()*maxlen
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens, position_ids, attention_mask)
logits = logits[:, context_length - 1, :] / args.temperature
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_present=True)
logits = logits[:, -1].view(batch_size,-1).contiguous()
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
tokens[0, context_length] = prev[0]
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
print_logits = []
for p in prev:
print_logits.append([logits[i, p].item() for i in range(batch_size)])
started = context_lengths <= context_length
tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started)
context_length += 1
counter += 1
output_tokens_list = tokens.view(-1).contiguous()
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
token_end = decode_tokens.find("<|endoftext|>")
done_token = (prev == eos_id).byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
was_done = is_done
is_done = is_done | done_token
done = torch.all(is_done)
if mpu.get_model_parallel_rank() == 0 and (counter % 16 == 0 or token_end != -1):
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")]
print("\nGPT2:", trim_decode_tokens, flush=True)
if token_end != -1:
yield tokens, lengths
if done:
break
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
output_tokens_list = tokens.view(-1).contiguous()
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")]
print("\nGPT2:", trim_decode_tokens, flush=True)
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
def prepare_tokenizer(args):
tokenizer_args = {
......@@ -232,7 +439,10 @@ def prepare_tokenizer(args):
args.eod_token = tokenizer.get_command('eos').Id
after = tokenizer.num_tokens
while after % mpu.get_model_parallel_world_size() != 0:
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
if multiple != 0:
while (after % multiple) != 0:
after += 1
args.vocab_size = after
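As a worked example of the padding rule above (hypothetical numbers; 50257 is the standard GPT-2 BPE vocabulary size):

vocab_size = 50257                 # GPT-2 BPE vocabulary
multiple = 128 * 2                 # make-vocab-size-divisible-by * model parallel world size
padded = vocab_size
while padded % multiple != 0:
    padded += 1
print(padded)                      # 50432 = 197 * 256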
......@@ -267,10 +477,19 @@ def main():
model = setup_model(args)
#setting default batch size to 1
args.batch_size = 1
# args.batch_size = 1
args.device = torch.cuda.current_device()
#generate samples
generate_samples(model, tokenizer, args, torch.cuda.current_device())
if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model, tokenizer, args)
else:
generate_samples_interactive(model, tokenizer, args)
else:
write_and_generate_samples_unconditional(model, tokenizer, args)
if __name__ == "__main__":
......
......@@ -18,36 +18,48 @@ import torch
from torch.optim.lr_scheduler import _LRScheduler
import math
from utils import print_rank_0
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve."""
"""Anneals the learning rate"""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
DECAY_STYLES = ['linear', 'cosine', 'constant', 'None']
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1):
def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
decay_style=None, last_iter=-1, min_lr=0.0,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
self.optimizer = optimizer
self.start_lr = start_lr
self.min_lr = min_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter + 1
self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
else None
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
self.step(self.num_iters)
if torch.distributed.get_rank() == 0:
print('learning rate decaying', decay_style)
def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4
num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * self.num_iters / self.warmup_iter
return float(self.start_lr) * num_iters_ / self.warmup_iter
else:
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter)
lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
return self.start_lr / 2.0 * (math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
#TODO: implement exponential decay
return self.start_lr
lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
else:
return self.start_lr
lr = self.start_lr
return max(lr, self.min_lr)
def step(self, step_num=None):
if step_num is None:
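For reference, the schedule that get_lr above implements can be restated as a standalone function; this is a minimal sketch of the same arithmetic (linear warmup, then the selected decay, floored at min_lr), with the function name chosen here for illustration:

import math

def annealed_lr(start_lr, min_lr, warmup_iter, end_iter, num_iters, decay_style='cosine'):
    # clamp progress so iterations past end_iter hold the final value
    n = min(num_iters, end_iter - warmup_iter)
    if warmup_iter > 0 and num_iters <= warmup_iter:
        return start_lr * n / warmup_iter                       # linear warmup
    if decay_style == 'linear':
        lr = start_lr * ((end_iter - (n - warmup_iter)) / end_iter)
    elif decay_style == 'cosine':
        lr = start_lr / 2.0 * (math.cos(math.pi * (n - warmup_iter) / end_iter) + 1)
    else:                                                       # 'constant' / None
        lr = start_lr
    return max(lr, min_lr)                                      # never drop below the floor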
......@@ -63,14 +75,38 @@ class AnnealingLR(_LRScheduler):
'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters,
'decay_style': self.decay_style,
'end_iter': self.end_iter
'end_iter': self.end_iter,
'min_lr': self.min_lr
}
return sd
def check_and_set_(self, cls_value, sd_value, name):
if self.override_lr_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value
else:
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, 'AnnealingLR: class input value ' \
'and checkpoint values for {} do not match'.format(name)
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name))
return sd_value
def load_state_dict(self, sd):
self.start_lr = sd['start_lr']
self.warmup_iter = sd['warmup_iter']
self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
'learning rate')
self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.warmup_iter = self.check_and_set_(self.warmup_iter,
sd['warmup_iter'],
'warmup iterations')
self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
'total number of iterations')
self.decay_style = self.check_and_set_(self.decay_style,
sd['decay_style'],
'decay style')
self.num_iters = sd['num_iters']
self.end_iter = sd['end_iter']
self.decay_style = sd['decay_style']
self.step(self.num_iters)
......@@ -65,6 +65,14 @@ class GPT2Model(torch.nn.Module):
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self.tokentype_embeddings = None
self.hidden_size = hidden_size
# Initialize the position embeddings.
init_method(self.position_embeddings.weight)
......@@ -80,18 +88,39 @@ class GPT2Model(torch.nn.Module):
checkpoint_activations,
checkpoint_num_layers)
def forward(self, input_ids, position_ids, attention_mask):
def add_tokentype_embeddings(self, num_tokentypes):
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings are already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
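A minimal, self-contained sketch of the pattern described in the comment above: tokentype embeddings start out as None so a checkpoint trained without them still loads, and are attached later by a method call. The ToyEmbedder class is purely illustrative, not the real GPT2Model:

import torch

class ToyEmbedder(torch.nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.tokentype_embeddings = None        # attached later, on demand
        self.hidden_size = hidden_size

    def add_tokentype_embeddings(self, num_tokentypes):
        assert self.tokentype_embeddings is None
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)

    def forward(self, input_ids, tokentype_ids=None):
        embeddings = self.word_embeddings(input_ids)
        if tokentype_ids is not None:
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        return embeddings

toy = ToyEmbedder()
toy.add_tokentype_embeddings(2)                 # e.g. segment A / segment B
out = toy(torch.tensor([[1, 2, 3]]), torch.tensor([[0, 0, 1]]))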
def forward(self, input_ids, position_ids, attention_mask,
layer_past=None, get_present=False, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
# Transformer.
transformer_output = self.transformer(embeddings, attention_mask)
transformer_output = self.transformer(embeddings, attention_mask,
layer_past=layer_past,
get_present=get_present)
if get_present:
transformer_output, presents = transformer_output
# Parallel logits.
transformer_output_parallel = mpu.copy_to_model_parallel_region(
......@@ -100,9 +129,12 @@ class GPT2Model(torch.nn.Module):
self.word_embeddings.weight)
if self.parallel_output:
return logits_parallel
return mpu.gather_from_model_parallel_region(logits_parallel)
output = logits_parallel
else:
output = mpu.gather_from_model_parallel_region(logits_parallel)
if get_present:
output = [output, presents]
return output
def gpt2_get_params_for_weight_decay_optimization(module):
......
......@@ -98,7 +98,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, ltor_mask):
def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
......@@ -112,13 +112,24 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.hidden_size_per_attention_head)
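# new scaling: dividing query and key by norm_factor = sqrt(sqrt(d)) each still applies
# the usual 1/sqrt(d) factor to their product, but keeps the fp16 matmul inputs smaller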
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
attention_scores = torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
# Apply the left to right attention mask.
if get_present:
with torch.no_grad():
if layer_past is not None:
ltor_mask = ltor_mask[...,attention_scores.size(3)-1, :attention_scores.size(3)].unsqueeze(2)
else:
ltor_mask = ltor_mask[...,:attention_scores.size(3), :attention_scores.size(3)]
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
......@@ -143,6 +154,9 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
output = self.dense(context_layer)
output = self.output_dropout(output)
if get_present:
output = [output, present]
return output
......@@ -268,14 +282,16 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
init_method,
output_layer_init_method=output_layer_init_method)
def forward(self, hidden_states, ltor_mask):
def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output, ltor_mask)
attention_output = self.attention(layernorm_output, ltor_mask, layer_past=layer_past, get_present=get_present)
if get_present:
attention_output, presents = attention_output
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
......@@ -285,6 +301,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
# Second residual connection.
output = layernorm_input + mlp_output
if get_present:
output = [output, presents]
return output
......@@ -376,7 +395,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, attention_mask):
def forward(self, hidden_states, attention_mask, layer_past=None, get_present=False):
def custom(start, end):
def custom_forward(*inputs):
......@@ -387,7 +406,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
return x_
return custom_forward
if self.checkpoint_activations:
if self.checkpoint_activations and not get_present:
l = 0
num_layers = len(self.layers)
chunk_length = self.checkpoint_num_layers
......@@ -396,11 +415,20 @@ class GPT2ParallelTransformer(torch.nn.Module):
hidden_states, attention_mask)
l += chunk_length
else:
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
presents = []
for i, layer in enumerate(self.layers):
past = None
if layer_past is not None:
past = layer_past[i]
hidden_states = layer(hidden_states, attention_mask, layer_past=past, get_present=get_present)
if get_present:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
output = self.final_layernorm(hidden_states)
if get_present:
output = [output, presents]
return output
......
......@@ -46,7 +46,8 @@ from utils import report_memory
from utils import print_args
from utils import print_params_min_max_norm
from utils import print_rank_0
from utils import enable_adlr_autoresume
from utils import check_adlr_autoresume_termination
def get_model(args):
"""Build the model."""
......@@ -114,7 +115,8 @@ def get_optimizer(model, args):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
betas = (0.9, 0.999)
optimizer = Adam(param_groups, betas=betas,
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
......@@ -145,7 +147,10 @@ def get_learning_rate_scheduler(optimizer, args):
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step)
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
......@@ -299,7 +304,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args):
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model."""
# Turn on training mode which enables dropout.
......@@ -326,15 +331,37 @@ def train(model, optimizer, lr_scheduler,
iteration += 1
# Update losses.
total_lm_loss += lm_loss.data.detach().float()
total_nsp_loss += nsp_loss.data.detach().float()
current_lm_loss = lm_loss.data.detach().float()
current_nsp_loss = nsp_loss.data.detach().float()
total_lm_loss += current_lm_loss
total_nsp_loss += current_nsp_loss
# Logging.
if iteration % args.log_interval == 0:
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('lm_loss', current_lm_loss, iteration)
writer.add_scalar('nsp_loss', current_nsp_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
avg_nsp_loss = total_nsp_loss.item() / args.log_interval
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
......@@ -351,9 +378,13 @@ def train(model, optimizer, lr_scheduler,
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
timers.log(['forward', 'backward', 'optimizer', 'batch generator',
'data loader'],
normalizer=args.log_interval)
timers.log(timers_to_log, normalizer=args.log_interval)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
# Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
......@@ -361,8 +392,8 @@ def train(model, optimizer, lr_scheduler,
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, False)
evaluate_and_print_results(prefix, val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
......@@ -413,7 +444,8 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False):
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss, nsp_loss = evaluate(data_iterator, model,
args, timers, verbose)
......@@ -428,6 +460,11 @@ def evaluate_and_print_results(prefix, data_iterator, model,
print_rank_0(string)
print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_lm_loss', lm_loss, iteration)
writer.add_scalar('val_nsp_loss', nsp_loss, iteration)
writer.add_scalar('val_total_loss', val_loss, iteration)
return val_loss
......@@ -471,7 +508,8 @@ def get_train_val_test_data(args):
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
data_config = configure_data()
data_config.set_defaults(data_set_type='BERT', transpose=False)
ds_type = 'BERT'
data_config.set_defaults(data_set_type=ds_type, transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(args)
before = tokenizer.num_tokens
after = before
......@@ -514,11 +552,27 @@ def main():
# Arguments.
args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain BERT model')
print_args(args)
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
......@@ -534,11 +588,15 @@ def main():
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \
args.eval_interval
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
......@@ -556,11 +614,12 @@ def main():
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args)
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False)
model, args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
......@@ -574,7 +633,7 @@ def main():
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True)
model, args, None, 0, timers, True)
if __name__ == "__main__":
......
......@@ -15,10 +15,6 @@
"""Pretrain GPT2"""
# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False
from datetime import datetime
import os
import random
......@@ -33,10 +29,7 @@ from fp16 import FP16_Optimizer
from learning_rates import AnnealingLR
from model import GPT2Model
from model import gpt2_get_params_for_weight_decay_optimization
if USE_TORCH_DDP:
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
else:
from model import DistributedDataParallel as DDP
from model import DistributedDataParallel as LocalDDP
import mpu
from apex.optimizers import FusedAdam as Adam
from utils import Timers
......@@ -46,10 +39,11 @@ from utils import report_memory
from utils import print_args
from utils import print_params_min_max_norm
from utils import print_rank_0
from utils import enable_adlr_autoresume
from utils import check_adlr_autoresume_termination
from gpt2_data_loader import make_gpt2_dataloaders
def get_model(args):
"""Build the model."""
......@@ -79,12 +73,18 @@ def get_model(args):
model = FP16_Module(model)
# Wrap model for distributed training.
if USE_TORCH_DDP:
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = DDP(model, device_ids=[i], output_device=i,
args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else:
model = DDP(model)
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model
......@@ -93,7 +93,7 @@ def get_optimizer(model, args):
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (DDP, FP16_Module)):
while isinstance(model, (args.DDP_type, FP16_Module)):
model = model.module
param_groups = gpt2_get_params_for_weight_decay_optimization(model)
......@@ -136,7 +136,10 @@ def get_learning_rate_scheduler(optimizer, args):
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step)
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
......@@ -159,7 +162,8 @@ def setup_model_and_optimizer(args):
def get_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask):
reset_attention_mask,
eod_mask_loss):
# Extract batch size and sequence length.
batch_size, seq_length = data.size()
......@@ -175,6 +179,7 @@ def get_masks_and_position_ids(data,
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
......@@ -246,7 +251,8 @@ def get_batch(data_iterator, args, timers):
tokens,
args.eod_token,
args.reset_position_ids,
args.reset_attention_mask)
args.reset_attention_mask,
args.eod_mask_loss)
# Convert
if args.fp16:
attention_mask = attention_mask.half()
......@@ -292,7 +298,7 @@ def backward_step(optimizer, model, lm_loss, args, timers):
reduced_losses = lm_loss.view(1)
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
if not USE_TORCH_DDP:
if args.DDP_impl == 'local':
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
......@@ -343,7 +349,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args):
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model."""
# Turn on training mode which enables dropout.
......@@ -369,13 +375,37 @@ def train(model, optimizer, lr_scheduler,
iteration += 1
# Update losses.
total_lm_loss += lm_loss.data.detach().float()
current_lm_loss = lm_loss.data.detach().float()
total_lm_loss += current_lm_loss
# Logging.
if iteration % args.log_interval == 0:
if args.DDP_impl == 'torch':
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
else:
timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('train_loss', current_lm_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
......@@ -390,14 +420,13 @@ def train(model, optimizer, lr_scheduler,
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
if USE_TORCH_DDP:
timers.log(['forward', 'backward', 'optimizer',
'batch generator', 'data loader'],
normalizer=args.log_interval)
else:
timers.log(['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader'],
normalizer=args.log_interval)
timers.log(timers_to_log, normalizer=args.log_interval)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
# Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
......@@ -405,8 +434,8 @@ def train(model, optimizer, lr_scheduler,
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, False)
evaluate_and_print_results(prefix, val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
......@@ -436,7 +465,7 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
# Forward evaluation.
lm_loss = forward_step(data_iterator, model, args, timers)
# Reduce across processes.
if isinstance(model, DDP):
if isinstance(model, args.DDP_type):
torch.distributed.all_reduce(lm_loss.data)
lm_loss.data = lm_loss.data / args.world_size
......@@ -450,7 +479,8 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False):
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss = evaluate(data_iterator, model, args, timers, verbose)
lm_ppl = math.exp(min(20, lm_loss))
......@@ -463,6 +493,10 @@ def evaluate_and_print_results(prefix, data_iterator, model,
print_rank_0(string)
print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_loss', lm_loss, iteration)
writer.add_scalar('val_ppl', lm_ppl, iteration)
return lm_loss
......@@ -555,11 +589,27 @@ def main():
# Arguments.
args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain GPT2 model')
print_args(args)
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
......@@ -576,11 +626,15 @@ def main():
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \
args.eval_interval
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
else:
......@@ -598,12 +652,13 @@ def main():
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args)
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False)
model, args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer,
......@@ -618,7 +673,7 @@ def main():
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True)
model, args, None, 0, timers, True)
if __name__ == "__main__":
......
#!/bin/bash
CHECKPOINT_PATH=/path/to/checkpoint
CHECKPOINT_PATH=checkpoints/gpt2_345m/
MPSIZE=1
NLAYERS=24
NHIDDEN=1024
NATT=16
NLAYERS=12
NHIDDEN=768
NATT=12
MAXSEQLEN=1024
#SAMPLING ARGS
......@@ -26,4 +26,7 @@ python generate_samples.py \
--out-seq-length $MAXSEQLEN \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP
--genfile dbg_unconditional.json \
--num-samples 10 \
--top_p $TOPP \
--recompute
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -28,6 +28,8 @@ parser.add_argument('--data-path', type=str, required=True,
help='Data path for evaluation data')
parser.add_argument('--cloze-eval', action='store_true',
help='Run lambada cloze eval instead of perplexity eval.')
parser.add_argument('--strict-lambada', action='store_true',
help='use the more difficult formulation of lambada')
parser.add_argument('--webtext-eval', action='store_true',
help='Run webtext PPL eval instead of wikitext PPL eval.')
parser.add_argument('--eval-iters', default=5000, type=int,
......@@ -38,6 +40,9 @@ parser.add_argument('--load-openai', action='store_true',
help='Load weights from saved openai/hf checkpoints')
parser.add_argument('--cache-dir', type=str, default='cache',
help='directory to cache gpt2 tokenizers')
parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value. '
'This is added for computational efficiency reasons.')
args = parser.parse_args()
multinode_args = ''
......@@ -60,18 +65,23 @@ CMD = ' --model-parallel-size {model_par} \
--attention-dropout 0.1 \
--fp16 \
--overlapping-eval 32 \
--make-vocab-size-divisible-by {make_vocab_size_divisible_by} \
--cache-dir {cache} '.format(model_par=args.model_parallel_size,
nlayers=args.num_layers,
hidden=args.hidden_size,
model=args.model_path,
batch=args.batch_size,
natt=args.num_attention_heads,
make_vocab_size_divisible_by=args.make_vocab_size_divisible_by,
cache=args.cache_dir)
if args.load_openai:
CMD += ' --load-openai '
if args.cloze_eval:
CMD += ' --valid-data {} '.format(args.data_path)
CMD += ' --cloze-eval '
if args.strict_lambada:
CMD += ' --strict-lambada '
CMD = 'evaluate_gpt2.py' + CMD
print('Running Lambada Eval Command:', flush=True)
elif args.webtext_eval:
......
"""
Takes corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, dev.json, test.json files
under `output_dir`.
Note: This code has the potential to overwrite files with the names
train.json, dev.json, test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
for i, l in enumerate(f.readlines()):
l = l.strip()
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
#calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)