"vscode:/vscode.git/clone" did not exist on "1039d61e4ff18e24812b2e4bdd2a3ae232739f6c"
Commit a54978bb authored by Jared Casper

Merge branch 'staging' into 'master'

Updating public repo with latest changes.

See merge request ADLR/megatron-lm!1
parents 93ab4bea 5d402eb4
__pycache__
......@@ -114,7 +114,8 @@ def add_training_args(parser):
help='report interval')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after this many new iterations.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory')
group.add_argument('--seed', type=int, default=1234,
help='random seed')
# Batch producer arguments
......@@ -123,6 +124,8 @@ def add_training_args(parser):
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention mask after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens')
# Learning rate.
group.add_argument('--lr-decay-iters', type=int, default=None,
......@@ -133,9 +136,25 @@ def add_training_args(parser):
help='learning rate decay function')
group.add_argument('--lr', type=float, default=1.0e-4,
help='initial learning rate')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--warmup', type=float, default=0.01,
help='percentage of data to warmup on (.01 = 1% of all '
'training iters). Default 0.01')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style) from input '
'arguments and ignore values from checkpoints. Note '
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style) '
'and ignore values from the input arguments. Note '
'that all the above values will be taken from the '
'checkpoint.')
# model checkpointing
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
......@@ -163,8 +182,17 @@ def add_training_args(parser):
group.add_argument('--distributed-backend', default='nccl',
help='which backend to use for distributed '
'training. One of [gloo, nccl]')
group.add_argument('--DDP-impl', default='local',
help='which DistributedDataParallel implementation '
'to use. One of [local, torch]')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher')
# autoresume
group.add_argument('--adlr-autoresume', action='store_true',
help='enable autoresume on adlr cluster.')
group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
help='interval over which to check for the autoresume '
'termination signal')
return parser
......@@ -193,6 +221,8 @@ def add_evaluation_args(parser):
help='sliding window for overlapping eval ')
group.add_argument('--cloze-eval', action='store_true',
help='Evaluation dataset from `--valid-data` is a cloze task')
group.add_argument('--strict-lambada', action='store_true',
help='use the more difficult formulation of LAMBADA')
group.add_argument('--eval-hf', action='store_true',
help='perform evaluation with a HuggingFace OpenAI model. '
'Use `--load` to specify the weights path to be loaded')
......@@ -207,9 +237,23 @@ def add_text_generate_args(parser):
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=256)
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default="",
help='get input from file instead of interactive mode, '
'each line is an input' )
group.add_argument("--sample-output-file", type=str, default="",
help='output file for samples generated from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='number of samples to generate unconditionally, '
'defaults to 0 (interactive conditional sampling)')
group.add_argument("--genfile", type=str,
help='output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='during generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
......
......@@ -148,7 +148,8 @@ def make_loaders(args):
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir,
'max_preds_per_seq': args.max_preds_per_seq,
'presplit_sentences': args.presplit_sentences}
'presplit_sentences': args.presplit_sentences,
'parallel_group': mpu.get_data_parallel_group()}
eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
......
......@@ -16,6 +16,8 @@
import os
import math
import torch
from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
......@@ -61,7 +63,8 @@ def supported_corpus(corpus_name):
def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None, **kwargs):
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
parallel_group=None, **kwargs):
"""function to create datasets+tokenizers for common options"""
if isinstance(process_fn, str):
process_fn = eval(process_fn)
......@@ -76,11 +79,19 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
named_corpora = True
name = path_
path_ = corpora.NAMED_CORPORA[path_].PATH
if not exists_lazy(path_, data_type='data'):
if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
# create cached version of dataset for lazy loading if it doesn't exist
text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
make_lazy(path_, text.X, data_type='data')
# This should be a barrier, but the NCCL barrier assumes
# device_index == rank, which does not hold in the model
# parallel case.
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_group)
assert counts[0].item() == torch.distributed.get_world_size(
group=parallel_group)
text = lazy_array_loader(path_, data_type='data', map_fn=process_fn)
else:
# get dataset
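A minimal stand-alone sketch of the all_reduce-based synchronization used in the hunk above, assuming torch.distributed and CUDA are already initialized; the function name wait_for_others and the group argument are illustrative, not part of the patch.

import torch
import torch.distributed as dist

def wait_for_others(group=None):
    # Stand-in for a barrier: each rank contributes 1 and the all_reduce
    # only completes once every rank in `group` has reached this point.
    # Used instead of dist.barrier() because the NCCL barrier assumes
    # device_index == rank, which does not hold with model parallelism.
    counts = torch.cuda.LongTensor([1])
    dist.all_reduce(counts, group=group)
    assert counts[0].item() == dist.get_world_size(group=group)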
......@@ -107,15 +118,17 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
# Split dataset into train/val/test (and wrap bert dataset)
if should_split(split):
ds = split_ds(ds, split)
if ds_type.lower() == 'bert':
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2':
ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
else:
if ds_type.lower() == 'bert':
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
dstype = bert_sentencepair_dataset
ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
elif ds_type.lower() == 'gpt2':
ds = GPT2Dataset(ds, max_seq_len=seq_length)
return ds, tokenizer
......@@ -461,6 +461,7 @@ class GPT2Dataset(data.Dataset):
weighted=True,
sample_across_doc=True,
random_across_doc_sampling=True,
bias_for_single_doc=False,
sentence_start=False, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
......@@ -473,6 +474,7 @@ class GPT2Dataset(data.Dataset):
self.weighted = weighted
self.sample_across_doc = sample_across_doc
self.random_across_doc_sampling = random_across_doc_sampling
self.bias_for_single_doc = bias_for_single_doc
self.sentence_start = sentence_start
self.init_weighting()
......@@ -510,7 +512,10 @@ class GPT2Dataset(data.Dataset):
# truncate or pad tokens
num_tokens = len(tokens)
tokens_to_strip = num_tokens - self.max_seq_len - 1
if self.bias_for_single_doc:
tokens_to_strip = num_tokens - self.max_seq_len - 1
else:
tokens_to_strip = num_tokens - 1
if tokens_to_strip > 0:
strip_left_tokens = rng.randint(tokens_to_strip + 1)
tokens = tokens[strip_left_tokens:]
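A small self-contained restatement of the left-truncation choice above: with bias_for_single_doc the sample is stripped only far enough to keep max_seq_len + 1 tokens of the current document, otherwise any left offset up to the last token may be chosen so that the rest of the sequence can be filled by sampling across documents. The function name pick_window is illustrative.

import random

def pick_window(tokens, max_seq_len, bias_for_single_doc, rng=None):
    rng = rng or random.Random(0)
    num_tokens = len(tokens)
    if bias_for_single_doc:
        # strip only enough to keep max_seq_len + 1 tokens of this document
        tokens_to_strip = num_tokens - max_seq_len - 1
    else:
        # allow any left offset; the remainder is filled by sampling
        # across documents (sample_across_doc)
        tokens_to_strip = num_tokens - 1
    if tokens_to_strip > 0:
        strip_left_tokens = rng.randint(0, tokens_to_strip)
        tokens = tokens[strip_left_tokens:]
    return tokens

print(len(pick_window(list(range(2000)), max_seq_len=1024, bias_for_single_doc=True)))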
......@@ -576,7 +581,7 @@ class bert_sentencepair_dataset(data.Dataset):
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True,**kwargs):
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
......@@ -758,7 +763,8 @@ class bert_sentencepair_dataset(data.Dataset):
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
max_num_tokens = max_seq_len - 3
max_num_tokens = self.calc_seq_len(max_seq_len)
# max_num_tokens = max_seq_len - 3
while True:
len_a = len(tokens_a)
len_b = len(tokens_b)
......@@ -782,6 +788,9 @@ class bert_sentencepair_dataset(data.Dataset):
trunc_types.pop()
return (tokens_a, token_types_a), (tokens_b, token_types_b)
def calc_seq_len(self, max_seq_len):
return max_seq_len - 3
def mask_token(self, idx, tokens, types, vocab_words, rng):
"""
helper function to mask `idx` token from `tokens` according to
......@@ -807,6 +816,11 @@ class bert_sentencepair_dataset(data.Dataset):
seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask
def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
return tokens, token_types
def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
"""
Mask sequence pair for BERT training according to:
......@@ -814,8 +828,7 @@ class bert_sentencepair_dataset(data.Dataset):
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
tokens, token_types = self.concat_tokens(tokens_a, token_types_a, tokens_b, token_types_b)
len_a = len(tokens_a)
len_b = len(tokens_b)
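A toy sketch of the two helpers factored out above: calc_seq_len reserves three positions for the [CLS]/[SEP] commands, and concat_tokens packs the pair as [CLS] A [SEP] B [SEP] with matching token types. The integer command ids CLS and SEP are stand-ins for tokenizer.get_command(...).Id.

CLS, SEP = 101, 102   # illustrative command ids

def calc_seq_len(max_seq_len):
    # room left for the two sentences after [CLS] ... [SEP] ... [SEP]
    return max_seq_len - 3

def concat_tokens(tokens_a, types_a, tokens_b, types_b):
    tokens = [CLS] + tokens_a + [SEP] + tokens_b + [SEP]
    token_types = [types_a[0]] + types_a + [types_a[0]] + types_b + [types_b[0]]
    return tokens, token_types

tokens, types = concat_tokens([11, 12], [0, 0], [21, 22, 23], [1, 1, 1])
assert len(tokens) == len(types) == 2 + 3 + 3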
......
......@@ -111,7 +111,7 @@ class lazy_array_loader(object):
lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb')
self._file = open(datapath, 'rb', buffering=0)
self.file = self._file
#memory map file if necessary
self.mem_map = mem_map
......
......@@ -795,7 +795,6 @@ class BertWordPieceTokenizer(Tokenizer):
Tokens = Tokens.tokenization
return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
......@@ -887,4 +886,3 @@ class GPT2BPETokenizer(Tokenizer):
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
......@@ -210,10 +210,13 @@ def forward_step(data_iterator, model, args, timers):
lm_loss = torch.sum(
losses.view(-1) * loss_mask.float())
else:
outputs = torch.argmax(output, -1).contiguous().view(-1)
acc = (outputs == lm_labels.contiguous().view(-1)).float()
loss_mask = loss_mask.contiguous().view(-1).float()
lm_loss = torch.sum(acc * loss_mask)
outputs = torch.argmax(output, -1)
correct = (outputs == lm_labels).float()
correct[(1-loss_mask).bool()] = 1
correct = correct.prod(-1)
lm_loss = correct.sum()
# loss_mask = loss_mask.contiguous().view(-1).float()
# lm_loss = torch.sum(acc * loss_mask)
return lm_loss
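The rewritten branch above scores a cloze sample as correct only if every position that counts is predicted correctly: ignored positions are set to 1 so that the product along the sequence becomes a 0/1 per-sample indicator. A toy check with made-up tensors:

import torch

outputs   = torch.tensor([[5, 7, 9], [5, 8, 9]])   # argmax predictions
labels    = torch.tensor([[5, 7, 0], [5, 7, 0]])   # targets
loss_mask = torch.tensor([[1, 1, 0], [1, 1, 0]])   # positions that count

correct = (outputs == labels).float()
correct[(1 - loss_mask).bool()] = 1      # ignored positions never hurt
per_sample = correct.prod(-1)            # 1 only if all counted positions match
print(per_sample)                        # tensor([1., 0.])
print(per_sample.sum())                  # tensor(1.)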
......@@ -345,7 +348,7 @@ def set_random_seed(seed):
class LM_Eval_Dataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None):
def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None, **kwargs):
self.tokens = tokens
self.seq_len = seq_len
self.pad_idx = pad_idx
......@@ -379,15 +382,30 @@ class LM_Eval_Dataset(torch.utils.data.Dataset):
return {'text': np.array(tokens), 'pad_mask': pad_mask}
class Lambada_Eval_Dataset(torch.utils.data.Dataset):
def __init__(self, path, tokenizer, seq_len):
def __init__(self, path, tokenizer, seq_len, strict=False, **kwargs):
self.seq_len = seq_len
self.pad_idx = tokenizer.get_command('pad').Id
self.tokenizer = tokenizer
self.strict = strict
self.tokens = []
self.labels = []
with open(path, 'r') as f:
for line in f.readlines():
text = json.loads(line)['text']
self.tokens.append(tokenizer.EncodeAsIds(text).tokenization)
tokens, labels = self.get_tokens(text)
self.tokens.append(tokens)
self.labels.append(labels)
def get_tokens(self, text):
if not self.strict:
tokens = self.tokenizer.EncodeAsIds(text).tokenization
return tokens[:-1], [tokens[-1]]
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = self.tokenizer.EncodeAsIds(text[:start_idx].strip()).tokenization
last_token = self.tokenizer.EncodeAsIds(' '+last_token).tokenization
return beginning_tokens, last_token
def __len__(self):
return len(self.tokens)
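A sketch of the two target constructions added above: the non-strict form takes the last BPE token of the full text as the target, while the strict LAMBADA form re-tokenizes the final whitespace-delimited word (with its leading space), so the target can span several BPE tokens. The toy whitespace tokenizer below is an assumption used only for illustration.

def get_tokens(text, encode, strict):
    # `encode` maps a string to a list of token ids
    # (stand-in for tokenizer.EncodeAsIds(...).tokenization)
    if not strict:
        tokens = encode(text)
        return tokens[:-1], [tokens[-1]]
    last_word = text.split()[-1]
    start_idx = text.rfind(last_word)
    context = encode(text[:start_idx].strip())
    target = encode(' ' + last_word)
    return context, target

vocab = {}
encode = lambda s: [vocab.setdefault(w, len(vocab)) for w in s.split()]
print(get_tokens("the cat sat on the mat", encode, strict=True))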
......@@ -397,7 +415,10 @@ class Lambada_Eval_Dataset(torch.utils.data.Dataset):
tokens = self.tokens[idx]
num_tokens = len(tokens)
pad_mask = [0]*num_tokens
pad_mask[-1] = 1
labels = self.labels[idx]
pad_mask += [1]*len(labels)
tokens = tokens+labels
num_tokens = len(tokens)
if num_tokens < self.seq_len+1:
num_pad = (self.seq_len+1-num_tokens)
pad_mask += [0]*(num_pad)
......@@ -442,7 +463,7 @@ def get_eval_data(args):
val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
args.overlapping_eval)
else:
val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len)
val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len, args.strict_lambada)
num_tokenized_tokens = 0
num_original_tokens = 0
val_dataloader = torch.utils.data.DataLoader(
......@@ -450,7 +471,9 @@ def get_eval_data(args):
before = tokenizer.num_tokens
after = before
while after % mpu.get_model_parallel_world_size() != 0:
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'.
format(before, after - before, after))
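The padded vocabulary is now rounded up to a multiple of make-vocab-size-divisible-by times the model-parallel world size, rather than just the world size. A small sketch of the arithmetic (the GPT-2 BPE vocabulary has 50257 tokens; the helper name is illustrative):

def pad_vocab_size(orig_vocab_size, make_vocab_size_divisible_by, model_parallel_size):
    multiple = make_vocab_size_divisible_by * model_parallel_size
    after = orig_vocab_size
    while after % multiple != 0:
        after += 1
    return after

# 50257 padded for 128-divisibility and 2-way model parallelism -> 50432
print(pad_vocab_size(50257, 128, 2))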
......
......@@ -17,6 +17,8 @@
import os
import random
import json
import copy
import numpy as np
import torch
import torch.nn.functional as F
......@@ -83,9 +85,10 @@ def setup_model(args):
return model
def get_batch(context_tokens, device, args):
def get_batch(context_tokens, args):
tokens = context_tokens
tokens = tokens.view(args.batch_size, -1).contiguous()
device = args.device
tokens = tokens.to(device)
# Get the masks and position ids.
......@@ -108,8 +111,8 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
if top_p > 0.0:
#convert to 1D
logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
# logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
......@@ -117,16 +120,33 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
#going back to 2D
logits=logits.view(1, -1).contiguous()
# logits=logits.view(1, -1).contiguous()
return logits
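A compact sketch of the batched filtering the patch moves to: top-k keeps only the k largest logits per row, and nucleus (top-p) filtering removes the tail of the sorted cumulative distribution row by row, always keeping at least the highest-probability token. The function name filter_logits is illustrative; it mirrors the per-row loop added to top_k_logits above.

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: [batch, vocab]
    if top_k > 0:
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumulative_probs > top_p
        # shift right so the first token above the threshold is kept
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            logits[i][sorted_indices[i][remove[i]]] = filter_value
    return logits

x = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
print(filter_logits(x.clone(), top_p=0.9))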
def generate_samples_input_from_file(model, tokenizer, args):
def generate_samples(model, tokenizer, args, device):
if args.sample_input_file == "":
if mpu.get_model_parallel_rank() == 0:
print("args.sample_input_file CAN NOT BE empty!\n")
return
if mpu.get_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file == "":
print("Argument: sample-output-file can't be empty, setting it to\n")
print("\t args.sample_input_file.out")
args.sample_output_file = args.sample_input_file+".out"
fname_out = open(args.sample_output_file, "w+")
context_count=0
model.eval()
with torch.no_grad():
......@@ -135,6 +155,74 @@ def generate_samples(model, tokenizer, args, device):
terminate_runs=0
if mpu.get_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
raw_text = "stop"
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
context_length = len(context_tokens)
if context_length >=args.seq_length//2:
print("\nContext length", context_length, \
"\nPlease give smaller context (half of the sequence length)!")
continue
else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:")
fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out.write("\n")
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
def generate_samples_interactive(model, tokenizer, args):
print_frequency = 24
context_count=0
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs=0
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text:
print('Prompt should not be empty!')
......@@ -161,60 +249,179 @@ def generate_samples(model, tokenizer, args, device):
if terminate_runs == 1:
return
pad_id = tokenizer.get_command('pad').Id
if context_length < args.seq_length:
context_tokens.extend([pad_id] * (args.seq_length - context_length))
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
context_length = context_length_tensor[0].item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, device, args)
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
counter = 0
org_context_length = context_length
while counter < (org_context_length + args.out_seq_length):
logits = model(tokens, position_ids, attention_mask)
logits = logits[:, context_length - 1, :] / args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
tokens[0, context_length] = prev[0]
context_length += 1
counter += 1
output_tokens_list = tokens.view(-1).contiguous()
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
token_end = decode_tokens.find("<|endoftext|>")
if mpu.get_model_parallel_rank() == 0 and (counter % 16 == 0 or token_end != -1):
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")]
print("\nGPT2:", trim_decode_tokens, flush=True)
if token_end != -1:
break
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
output_tokens_list = tokens.view(-1).contiguous()
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")]
print("\nGPT2:", trim_decode_tokens, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
if mpu.get_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
def generate_samples_unconditional(model, tokenizer, args):
num_samples = args.num_samples
context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)]
samples = []
# with open(args.genfile, 'w') as f:
ctr = 0
while True:
start_time = time.time()
for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args):
pass
# token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
if ctr%args.log_interval == 0:
print('Avg s/batch:', (time.time()-start_time)/min(args.log_interval, ctr+1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length-1]
text = tokenizer.DecodeIds(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length-1, 'finished': is_finished}
yield datum
ctr += 1
if ctr >= num_samples:
break
if ctr >= num_samples:
break
def write_and_generate_samples_unconditional(model, tokenizer, args):
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model, tokenizer, args):
f.write(json.dumps(datum)+'\n')
def pad_batch(batch, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
tokens.extend([pad_id]*(args.seq_length-context_length))
context_lengths.append(context_length)
return batch, context_lengths
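pad_batch right-pads each context to the full sequence length and records the original lengths, so generation can start from the shortest context in the batch. A stand-alone sketch with made-up pad id and sequence length:

def pad_batch(batch, pad_id, seq_length):
    context_lengths = []
    for tokens in batch:
        context_lengths.append(len(tokens))
        if len(tokens) < seq_length:
            tokens.extend([pad_id] * (seq_length - len(tokens)))
    return batch, context_lengths

batch, lengths = pad_batch([[7, 8, 9], [7, 8]], pad_id=0, seq_length=5)
print(batch)     # [[7, 8, 9, 0, 0], [7, 8, 0, 0, 0]]
print(lengths)   # [3, 2]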
def get_token_stream(model, context_tokens, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
# context_length = len(context_tokens)
# if context_length < args.seq_length:
# context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
# context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, args)
counter = 0
org_context_length = context_length
layer_past = None
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
for tokens, lengths in batch_token_iterator:
context_length += 1
yield tokens[:, :context_length], lengths
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1-boolean)*val1 + boolean*val2
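switch keeps the original context token for rows whose prompt is still being consumed and takes the newly sampled token for rows where generation has started. A tiny check with toy tensors:

import torch

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

context_tok = torch.tensor([11, 12, 13])   # tokens given at this position
sampled_tok = torch.tensor([91, 92, 93])   # tokens proposed by the model
started     = torch.tensor([0, 1, 1])      # rows whose context is already exhausted
print(switch(context_tok, sampled_tok, started))   # tensor([11, 92, 93])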
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.get_command('eos').Id
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda()*maxlen
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens, position_ids, attention_mask)
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_present=True)
logits = logits[:, -1].view(batch_size,-1).contiguous()
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
print_logits = []
for p in prev:
print_logits.append([logits[i, p].item() for i in range(batch_size)])
started = context_lengths <= context_length
tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started)
context_length += 1
counter += 1
done_token = (prev == eos_id).byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
was_done = is_done
is_done = is_done | done_token
done = torch.all(is_done)
yield tokens, lengths
if done:
break
def prepare_tokenizer(args):
......@@ -232,8 +439,11 @@ def prepare_tokenizer(args):
args.eod_token = tokenizer.get_command('eos').Id
after = tokenizer.num_tokens
while after % mpu.get_model_parallel_world_size() != 0:
after += 1
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
if multiple != 0:
while (after % multiple) != 0:
after += 1
args.vocab_size = after
print("prepare tokenizer done", flush=True)
......@@ -267,10 +477,19 @@ def main():
model = setup_model(args)
#setting default batch size to 1
args.batch_size = 1
# args.batch_size = 1
args.device = torch.cuda.current_device()
#generate samples
generate_samples(model, tokenizer, args, torch.cuda.current_device())
if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model, tokenizer, args)
else:
generate_samples_interactive(model, tokenizer, args)
else:
write_and_generate_samples_unconditional(model, tokenizer, args)
if __name__ == "__main__":
......
......@@ -18,36 +18,48 @@ import torch
from torch.optim.lr_scheduler import _LRScheduler
import math
from utils import print_rank_0
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve."""
"""Anneals the learning rate"""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
DECAY_STYLES = ['linear', 'cosine', 'constant', 'None']
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1):
def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
decay_style=None, last_iter=-1, min_lr=0.0,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
self.optimizer = optimizer
self.start_lr = start_lr
self.min_lr = min_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter + 1
self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
else None
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
self.step(self.num_iters)
if torch.distributed.get_rank() == 0:
print('learning rate decaying', decay_style)
def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4
num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * self.num_iters / self.warmup_iter
return float(self.start_lr) * num_iters_ / self.warmup_iter
else:
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter)
lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
return self.start_lr / 2.0 * (math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
#TODO: implement exponential decay
return self.start_lr
lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
else:
return self.start_lr
lr = self.start_lr
return max(lr, self.min_lr)
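A stand-alone restatement of the schedule above, under the same formulas: linear warmup to start_lr over warmup_iter steps, then linear or cosine decay over end_iter iterations, clamped at min_lr. The example values are illustrative.

import math

def annealed_lr(step, start_lr, warmup_iter, end_iter, decay_style='cosine', min_lr=0.0):
    num_iters = min(step, end_iter - warmup_iter)
    if warmup_iter > 0 and step <= warmup_iter:
        return start_lr * num_iters / warmup_iter
    if decay_style == 'linear':
        lr = start_lr * (end_iter - (num_iters - warmup_iter)) / end_iter
    elif decay_style == 'cosine':
        lr = start_lr / 2.0 * (math.cos(math.pi * (num_iters - warmup_iter) / end_iter) + 1)
    else:
        lr = start_lr
    return max(lr, min_lr)

for s in (0, 100, 1000, 5000, 10000):
    print(s, annealed_lr(s, 1.5e-4, warmup_iter=200, end_iter=10000, min_lr=1e-5))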
def step(self, step_num=None):
if step_num is None:
......@@ -63,14 +75,38 @@ class AnnealingLR(_LRScheduler):
'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters,
'decay_style': self.decay_style,
'end_iter': self.end_iter
'end_iter': self.end_iter,
'min_lr': self.min_lr
}
return sd
def check_and_set_(self, cls_value, sd_value, name):
if self.override_lr_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value
else:
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, 'AnnealingLR: class input value ' \
'and checkpoint values for {} do not match'.format(name)
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name))
return sd_value
def load_state_dict(self, sd):
self.start_lr = sd['start_lr']
self.warmup_iter = sd['warmup_iter']
self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
'learning rate')
self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.warmup_iter = self.check_and_set_(self.warmup_iter,
sd['warmup_iter'],
'warmup iterations')
self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
'total number of iterations')
self.decay_style = self.check_and_set_(self.decay_style,
sd['decay_style'],
'decay style')
self.num_iters = sd['num_iters']
self.end_iter = sd['end_iter']
self.decay_style = sd['decay_style']
self.step(self.num_iters)
......@@ -65,6 +65,14 @@ class GPT2Model(torch.nn.Module):
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrained model without
# token types and add them as needed.
self.tokentype_embeddings = None
self.hidden_size = hidden_size
# Initialize the position embeddings.
init_method(self.position_embeddings.weight)
......@@ -80,18 +88,39 @@ class GPT2Model(torch.nn.Module):
checkpoint_activations,
checkpoint_num_layers)
def forward(self, input_ids, position_ids, attention_mask):
def add_tokentype_embeddings(self, num_tokentypes):
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings are already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
def forward(self, input_ids, position_ids, attention_mask,
layer_past=None, get_present=False, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
# Transformer.
transformer_output = self.transformer(embeddings, attention_mask)
transformer_output = self.transformer(embeddings, attention_mask,
layer_past=layer_past,
get_present=get_present)
if get_present:
transformer_output, presents = transformer_output
# Parallel logits.
transformer_output_parallel = mpu.copy_to_model_parallel_region(
......@@ -100,9 +129,12 @@ class GPT2Model(torch.nn.Module):
self.word_embeddings.weight)
if self.parallel_output:
return logits_parallel
return mpu.gather_from_model_parallel_region(logits_parallel)
output = logits_parallel
else:
output = mpu.gather_from_model_parallel_region(logits_parallel)
if get_present:
output = [output, presents]
return output
def gpt2_get_params_for_weight_decay_optimization(module):
......
......@@ -98,7 +98,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, ltor_mask):
def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
......@@ -112,13 +112,24 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.hidden_size_per_attention_head)
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
attention_scores = torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
# Apply the left to right attention mask.
if get_present:
with torch.no_grad():
if layer_past is not None:
ltor_mask = ltor_mask[...,attention_scores.size(3)-1, :attention_scores.size(3)].unsqueeze(2)
else:
ltor_mask = ltor_mask[...,:attention_scores.size(3), :attention_scores.size(3)]
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
......@@ -143,6 +154,9 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
output = self.dense(context_layer)
output = self.output_dropout(output)
if get_present:
output = [output, present]
return output
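Two things changed in the attention forward above: cached keys/values from layer_past are concatenated along the sequence dimension for incremental decoding, and the 1/sqrt(d) scaling is split as sqrt(sqrt(d)) applied to both query and key (which is numerically friendlier in fp16 but mathematically identical). A toy numerical check, with made-up shapes:

import math
import torch

b, heads, s, d = 2, 4, 8, 16
q = torch.randn(b, heads, 1, d)          # query for the newest token
k_past = torch.randn(b, heads, s, d)     # cached keys from previous steps
k_new = torch.randn(b, heads, 1, d)
k = torch.cat((k_past, k_new), dim=-2)   # layer_past-style concatenation

norm = math.sqrt(math.sqrt(d))
scores_split = torch.matmul(q / norm, k.transpose(-1, -2) / norm)
scores_plain = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)
print(torch.allclose(scores_split, scores_plain, atol=1e-6))   # True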
......@@ -268,14 +282,16 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
init_method,
output_layer_init_method=output_layer_init_method)
def forward(self, hidden_states, ltor_mask):
def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output, ltor_mask)
attention_output = self.attention(layernorm_output, ltor_mask, layer_past=layer_past, get_present=get_present)
if get_present:
attention_output, presents = attention_output
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
......@@ -285,6 +301,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
# Second residual connection.
output = layernorm_input + mlp_output
if get_present:
output = [output, presents]
return output
......@@ -376,7 +395,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, attention_mask):
def forward(self, hidden_states, attention_mask, layer_past=None, get_present=False):
def custom(start, end):
def custom_forward(*inputs):
......@@ -387,7 +406,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
return x_
return custom_forward
if self.checkpoint_activations:
if self.checkpoint_activations and not get_present:
l = 0
num_layers = len(self.layers)
chunk_length = self.checkpoint_num_layers
......@@ -396,11 +415,20 @@ class GPT2ParallelTransformer(torch.nn.Module):
hidden_states, attention_mask)
l += chunk_length
else:
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
presents = []
for i, layer in enumerate(self.layers):
past = None
if layer_past is not None:
past = layer_past[i]
hidden_states = layer(hidden_states, attention_mask, layer_past=past, get_present=get_present)
if get_present:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
output = self.final_layernorm(hidden_states)
if get_present:
output = [output, presents]
return output
......
......@@ -46,7 +46,8 @@ from utils import report_memory
from utils import print_args
from utils import print_params_min_max_norm
from utils import print_rank_0
from utils import enable_adlr_autoresume
from utils import check_adlr_autoresume_termination
def get_model(args):
"""Build the model."""
......@@ -114,7 +115,8 @@ def get_optimizer(model, args):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
betas = (0.9, 0.999)
optimizer = Adam(param_groups, betas=betas,
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
......@@ -145,7 +147,10 @@ def get_learning_rate_scheduler(optimizer, args):
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step)
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
......@@ -299,7 +304,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args):
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model."""
# Turn on training mode which enables dropout.
......@@ -326,15 +331,37 @@ def train(model, optimizer, lr_scheduler,
iteration += 1
# Update losses.
total_lm_loss += lm_loss.data.detach().float()
total_nsp_loss += nsp_loss.data.detach().float()
current_lm_loss = lm_loss.data.detach().float()
current_nsp_loss = nsp_loss.data.detach().float()
total_lm_loss += current_lm_loss
total_nsp_loss += current_nsp_loss
# Logging.
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('lm_loss', current_lm_loss, iteration)
writer.add_scalar('nsp_loss', current_nsp_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
learning_rate = optimizer.param_groups[0]['lr']
avg_nsp_loss = total_nsp_loss.item() / args.log_interval
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
......@@ -351,9 +378,13 @@ def train(model, optimizer, lr_scheduler,
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
timers.log(['forward', 'backward', 'optimizer', 'batch generator',
'data loader'],
normalizer=args.log_interval)
timers.log(timers_to_log, normalizer=args.log_interval)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
# Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
......@@ -361,8 +392,8 @@ def train(model, optimizer, lr_scheduler,
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, False)
evaluate_and_print_results(prefix, val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
......@@ -413,7 +444,8 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False):
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss, nsp_loss = evaluate(data_iterator, model,
args, timers, verbose)
......@@ -428,6 +460,11 @@ def evaluate_and_print_results(prefix, data_iterator, model,
print_rank_0(string)
print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_lm_loss', lm_loss, iteration)
writer.add_scalar('val_nsp_loss', nsp_loss, iteration)
writer.add_scalar('val_total_loss', val_loss, iteration)
return val_loss
......@@ -471,7 +508,8 @@ def get_train_val_test_data(args):
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
data_config = configure_data()
data_config.set_defaults(data_set_type='BERT', transpose=False)
ds_type = 'BERT'
data_config.set_defaults(data_set_type=ds_type, transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(args)
before = tokenizer.num_tokens
after = before
......@@ -514,11 +552,27 @@ def main():
# Arguments.
args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain BERT model')
print_args(args)
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
......@@ -534,11 +588,15 @@ def main():
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \
args.eval_interval
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
......@@ -556,11 +614,12 @@ def main():
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args)
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False)
model, args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
......@@ -574,7 +633,7 @@ def main():
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True)
model, args, None, 0, timers, True)
if __name__ == "__main__":
......
......@@ -15,10 +15,6 @@
"""Pretrain GPT2"""
# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False
from datetime import datetime
import os
import random
......@@ -33,10 +29,7 @@ from fp16 import FP16_Optimizer
from learning_rates import AnnealingLR
from model import GPT2Model
from model import gpt2_get_params_for_weight_decay_optimization
if USE_TORCH_DDP:
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
else:
from model import DistributedDataParallel as DDP
from model import DistributedDataParallel as LocalDDP
import mpu
from apex.optimizers import FusedAdam as Adam
from utils import Timers
......@@ -46,10 +39,11 @@ from utils import report_memory
from utils import print_args
from utils import print_params_min_max_norm
from utils import print_rank_0
from utils import enable_adlr_autoresume
from utils import check_adlr_autoresume_termination
from gpt2_data_loader import make_gpt2_dataloaders
def get_model(args):
"""Build the model."""
......@@ -79,12 +73,18 @@ def get_model(args):
model = FP16_Module(model)
# Wrap model for distributed training.
if USE_TORCH_DDP:
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = DDP(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else:
model = DDP(model)
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model
......@@ -93,7 +93,7 @@ def get_optimizer(model, args):
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (DDP, FP16_Module)):
while isinstance(model, (args.DDP_type, FP16_Module)):
model = model.module
param_groups = gpt2_get_params_for_weight_decay_optimization(model)
......@@ -136,7 +136,10 @@ def get_learning_rate_scheduler(optimizer, args):
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step)
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
......@@ -159,7 +162,8 @@ def setup_model_and_optimizer(args):
def get_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask):
reset_attention_mask,
eod_mask_loss):
# Extract batch size and sequence length.
batch_size, seq_length = data.size()
......@@ -175,7 +179,8 @@ def get_masks_and_position_ids(data,
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
loss_mask[data == eod_token] = 0.0
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
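With the new --eod-mask-loss flag the loss mask only zeroes out end-of-document positions when explicitly requested. A toy sketch of the mask and (non-reset) position ids produced for one batch; the reset_position_ids / reset_attention_mask handling is omitted and the helper name masks_for is illustrative.

import torch

def masks_for(data, eod_token, eod_mask_loss):
    batch_size, seq_length = data.size()
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    return loss_mask, position_ids

data = torch.tensor([[5, 6, 50256, 7, 8]])     # 50256 = GPT-2 <|endoftext|>
loss_mask, position_ids = masks_for(data, eod_token=50256, eod_mask_loss=True)
print(loss_mask)      # tensor([[1., 1., 0., 1., 1.]])
print(position_ids)   # tensor([[0, 1, 2, 3, 4]])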
......@@ -246,7 +251,8 @@ def get_batch(data_iterator, args, timers):
tokens,
args.eod_token,
args.reset_position_ids,
args.reset_attention_mask)
args.reset_attention_mask,
args.eod_mask_loss)
# Convert
if args.fp16:
attention_mask = attention_mask.half()
......@@ -292,7 +298,7 @@ def backward_step(optimizer, model, lm_loss, args, timers):
reduced_losses = lm_loss.view(1)
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
if not USE_TORCH_DDP:
if args.DDP_impl == 'local':
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
......@@ -343,7 +349,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args):
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model."""
# Turn on training mode which enables dropout.
......@@ -369,13 +375,37 @@ def train(model, optimizer, lr_scheduler,
iteration += 1
# Update losses.
total_lm_loss += lm_loss.data.detach().float()
current_lm_loss = lm_loss.data.detach().float()
total_lm_loss += current_lm_loss
# Logging.
if args.DDP_impl == 'torch':
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
else:
timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('train_loss', current_lm_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
learning_rate = optimizer.param_groups[0]['lr']
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
......@@ -390,14 +420,13 @@ def train(model, optimizer, lr_scheduler,
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
if USE_TORCH_DDP:
timers.log(['forward', 'backward', 'optimizer',
'batch generator', 'data loader'],
normalizer=args.log_interval)
else:
timers.log(['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader'],
normalizer=args.log_interval)
timers.log(timers_to_log, normalizer=args.log_interval)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
# Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
......@@ -405,8 +434,8 @@ def train(model, optimizer, lr_scheduler,
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, False)
evaluate_and_print_results(prefix, val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
......@@ -436,7 +465,7 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
# Forward evaluation.
lm_loss = forward_step(data_iterator, model, args, timers)
# Reduce across processes.
if isinstance(model, DDP):
if isinstance(model, args.DDP_type):
torch.distributed.all_reduce(lm_loss.data)
lm_loss.data = lm_loss.data / args.world_size
......@@ -450,7 +479,8 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False):
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss = evaluate(data_iterator, model, args, timers, verbose)
lm_ppl = math.exp(min(20, lm_loss))
......@@ -463,6 +493,10 @@ def evaluate_and_print_results(prefix, data_iterator, model,
print_rank_0(string)
print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_loss', lm_loss, iteration)
writer.add_scalar('val_ppl', lm_ppl, iteration)
return lm_loss
......@@ -555,11 +589,27 @@ def main():
# Arguments.
args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain GPT2 model')
print_args(args)
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
......@@ -576,11 +626,15 @@ def main():
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \
args.eval_interval
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
else:
......@@ -598,12 +652,13 @@ def main():
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args)
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False)
model, args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer,
......@@ -618,7 +673,7 @@ def main():
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True)
model, args, None, 0, timers, True)
if __name__ == "__main__":
......
#!/bin/bash
CHECKPOINT_PATH=/path/to/checkpoint
CHECKPOINT_PATH=checkpoints/gpt2_345m/
MPSIZE=1
NLAYERS=24
NHIDDEN=1024
NATT=16
NLAYERS=12
NHIDDEN=768
NATT=12
MAXSEQLEN=1024
#SAMPLING ARGS
......@@ -26,4 +26,7 @@ python generate_samples.py \
--out-seq-length $MAXSEQLEN \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP
--genfile dbg_unconditional.json \
--num-samples 10 \
--top_p $TOPP \
--recompute
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -28,6 +28,8 @@ parser.add_argument('--data-path', type=str, required=True,
help='Data path for evaluation data')
parser.add_argument('--cloze-eval', action='store_true',
help='Run lambada cloze eval instead of perplexity eval.')
parser.add_argument('--strict-lambada', action='store_true',
help='use the more difficult formulation of LAMBADA')
parser.add_argument('--webtext-eval', action='store_true',
help='Run webtext PPL eval instead of wikitext PPL eval.')
parser.add_argument('--eval-iters', default=5000, type=int,
......@@ -38,6 +40,9 @@ parser.add_argument('--load-openai', action='store_true',
help='Load weights from saved openai/hf checkpoints')
parser.add_argument('--cache-dir', type=str, default='cache',
help='directory to cache gpt2 tokenizers')
parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value. '
'This is added for computational efficiency reasons.')
args = parser.parse_args()
multinode_args = ''
......@@ -60,18 +65,23 @@ CMD = ' --model-parallel-size {model_par} \
--attention-dropout 0.1 \
--fp16 \
--overlapping-eval 32 \
--make-vocab-size-divisible-by {make_vocab_size_divisible_by} \
--cache-dir {cache} '.format(model_par=args.model_parallel_size,
nlayers=args.num_layers,
hidden=args.hidden_size,
model=args.model_path,
batch=args.batch_size,
natt=args.num_attention_heads,
make_vocab_size_divisible_by=args.make_vocab_size_divisible_by,
cache=args.cache_dir)
if args.load_openai:
CMD += ' --load-openai '
if args.cloze_eval:
CMD += ' --valid-data {} '.format(args.data_path)
CMD += ' --cloze-eval '
if args.strict_lambada:
CMD += ' --strict-lambada '
CMD = 'evaluate_gpt2.py' + CMD
print('Running Lambada Eval Command:', flush=True)
elif args.webtext_eval:
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Takes a corpus of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits the data into train.json, val.json, and test.json
files under `output_dir`.
Note: this code may overwrite files named
train.json, val.json, or test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
for i, l in enumerate(f.readlines()):
l = l.strip()
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
#calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)