Commit a1d04b79 authored by Jared Casper

Updating public repo with latest changes.

parent 93ab4bea
__pycache__
...@@ -114,7 +114,8 @@ def add_training_args(parser): ...@@ -114,7 +114,8 @@ def add_training_args(parser):
help='report interval') help='report interval')
group.add_argument('--exit-interval', type=int, default=None, group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after this many new iterations.') help='Exit the program after this many new iterations.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory')
group.add_argument('--seed', type=int, default=1234, group.add_argument('--seed', type=int, default=1234,
help='random seed') help='random seed')
# Batch prodecuer arguments # Batch prodecuer arguments
...@@ -123,6 +124,8 @@ def add_training_args(parser): ...@@ -123,6 +124,8 @@ def add_training_args(parser):
group.add_argument('--reset-attention-mask', action='store_true', group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention maske after ' help='Reset self attention maske after '
'end-of-document token.') 'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens')
# Learning rate. # Learning rate.
group.add_argument('--lr-decay-iters', type=int, default=None, group.add_argument('--lr-decay-iters', type=int, default=None,
...@@ -133,9 +136,25 @@ def add_training_args(parser): ...@@ -133,9 +136,25 @@ def add_training_args(parser):
help='learning rate decay function') help='learning rate decay function')
group.add_argument('--lr', type=float, default=1.0e-4, group.add_argument('--lr', type=float, default=1.0e-4,
help='initial learning rate') help='initial learning rate')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--warmup', type=float, default=0.01, group.add_argument('--warmup', type=float, default=0.01,
help='percentage of data to warmup on (.01 = 1% of all ' help='percentage of data to warmup on (.01 = 1% of all '
'training iters). Default 0.01') 'training iters). Default 0.01')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style) from the input '
'arguments and ignore values from checkpoints. Note '
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style) '
'from the checkpoint and ignore the input arguments. '
'Note that all the above values will be taken from '
'the checkpoint.')
# model checkpointing # model checkpointing
group.add_argument('--save', type=str, default=None, group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.') help='Output directory to save checkpoints to.')
...@@ -163,8 +182,17 @@ def add_training_args(parser): ...@@ -163,8 +182,17 @@ def add_training_args(parser):
group.add_argument('--distributed-backend', default='nccl', group.add_argument('--distributed-backend', default='nccl',
help='which backend to use for distributed ' help='which backend to use for distributed '
'training. One of [gloo, nccl]') 'training. One of [gloo, nccl]')
group.add_argument('--DDP-impl', default='local',
help='which DistributedDataParallel implementation '
'to use. One of [local, torch]')
group.add_argument('--local_rank', type=int, default=None, group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher') help='local rank passed from distributed launcher')
# autoresume
group.add_argument('--adlr-autoresume', action='store_true',
help='enable autoresume on adlr cluster.')
group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
help='interval over which to check for an autoresume '
'termination signal')
return parser return parser
...@@ -193,6 +221,8 @@ def add_evaluation_args(parser): ...@@ -193,6 +221,8 @@ def add_evaluation_args(parser):
help='sliding window for overlapping eval ') help='sliding window for overlapping eval ')
group.add_argument('--cloze-eval', action='store_true', group.add_argument('--cloze-eval', action='store_true',
help='Evaluation dataset from `--valid-data` is a cloze task') help='Evaluation dataset from `--valid-data` is a cloze task')
group.add_argument('--strict-lambada', action='store_true',
help='use more difficult formulation of lambada')
group.add_argument('--eval-hf', action='store_true', group.add_argument('--eval-hf', action='store_true',
help='perform evaluation with huggingface openai model.' help='perform evaluation with huggingface openai model.'
'use `--load` to specify weights path to be loaded') 'use `--load` to specify weights path to be loaded')
...@@ -207,9 +237,23 @@ def add_text_generate_args(parser): ...@@ -207,9 +237,23 @@ def add_text_generate_args(parser):
group = parser.add_argument_group('Text generation', 'configurations') group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0) group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0) group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0) group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=256) group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default="",
help='get input from file instead of interactive mode, '
'each line is an input' )
group.add_argument("--sample-output-file", type=str, default="",
help='output file for samples generated from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='number of samples to generate unconditionally, '
'defaults to 0, which selects interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='during generation recompute all attention '
'instead of using previously computed keys/values.')
return parser return parser
......
...@@ -148,7 +148,8 @@ def make_loaders(args): ...@@ -148,7 +148,8 @@ def make_loaders(args):
'model_type': args.tokenizer_model_type, 'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir, 'cache_dir': args.cache_dir,
'max_preds_per_seq': args.max_preds_per_seq, 'max_preds_per_seq': args.max_preds_per_seq,
'presplit_sentences': args.presplit_sentences} 'presplit_sentences': args.presplit_sentences,
'parallel_group': mpu.get_data_parallel_group()}
eval_set_args = copy.copy(data_set_args) eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.] eval_set_args['split'] = [1.]
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
import os import os
import math import math
import torch
from .samplers import DistributedBatchSampler from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
...@@ -61,7 +63,8 @@ def supported_corpus(corpus_name): ...@@ -61,7 +63,8 @@ def supported_corpus(corpus_name):
def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.], def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None, delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None, tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None, **kwargs): model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
parallel_group=None, **kwargs):
"""function to create datasets+tokenizers for common options""" """function to create datasets+tokenizers for common options"""
if isinstance(process_fn, str): if isinstance(process_fn, str):
process_fn = eval(process_fn) process_fn = eval(process_fn)
...@@ -76,11 +79,19 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -76,11 +79,19 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
named_corpora = True named_corpora = True
name = path_ name = path_
path_ = corpora.NAMED_CORPORA[path_].PATH path_ = corpora.NAMED_CORPORA[path_].PATH
if not exists_lazy(path_, data_type='data'): if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
# create cached version of dataset for lazy loading if it doesn't exist # create cached version of dataset for lazy loading if it doesn't exist
text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose) delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
make_lazy(path_, text.X, data_type='data') make_lazy(path_, text.X, data_type='data')
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_group)
assert counts[0].item() == torch.distributed.get_world_size(
group=parallel_group)
text = lazy_array_loader(path_, data_type='data', map_fn=process_fn) text = lazy_array_loader(path_, data_type='data', map_fn=process_fn)
else: else:
# get dataset # get dataset
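The all_reduce just above stands in for a barrier: torch.distributed.barrier with the NCCL backend assumes device_index == rank, which does not hold once model parallelism maps several ranks onto one node's devices. A minimal sketch of the same pattern (process-group setup and CUDA device selection are assumed to have happened elsewhere):

import torch
import torch.distributed as dist

def group_barrier(group=None):
    # Synchronize the ranks of `group` with an all_reduce instead of
    # dist.barrier(), so it also works when device_index != rank.
    counts = torch.cuda.LongTensor([1])
    dist.all_reduce(counts, group=group)
    # every rank contributed a 1, so the sum must equal the group size
    assert counts[0].item() == dist.get_world_size(group=group)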
...@@ -107,15 +118,17 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -107,15 +118,17 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
# Split dataset into train/val/test (and wrap bert dataset) # Split dataset into train/val/test (and wrap bert dataset)
if should_split(split): if should_split(split):
ds = split_ds(ds, split) ds = split_ds(ds, split)
if ds_type.lower() == 'bert': if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds] dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2': elif ds_type.lower() == 'gpt2':
ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds] ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
else: else:
if ds_type.lower() == 'bert': if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences) dstype = bert_sentencepair_dataset
ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
elif ds_type.lower() == 'gpt2': elif ds_type.lower() == 'gpt2':
ds = GPT2Dataset(ds, max_seq_len=seq_length) ds = GPT2Dataset(ds, max_seq_len=seq_length)
return ds, tokenizer return ds, tokenizer
...@@ -461,6 +461,7 @@ class GPT2Dataset(data.Dataset): ...@@ -461,6 +461,7 @@ class GPT2Dataset(data.Dataset):
weighted=True, weighted=True,
sample_across_doc=True, sample_across_doc=True,
random_across_doc_sampling=True, random_across_doc_sampling=True,
bias_for_single_doc=False,
sentence_start=False, **kwargs): sentence_start=False, **kwargs):
self.ds = ds self.ds = ds
self.ds_len = len(self.ds) self.ds_len = len(self.ds)
...@@ -473,6 +474,7 @@ class GPT2Dataset(data.Dataset): ...@@ -473,6 +474,7 @@ class GPT2Dataset(data.Dataset):
self.weighted = weighted self.weighted = weighted
self.sample_across_doc = sample_across_doc self.sample_across_doc = sample_across_doc
self.random_across_doc_sampling = random_across_doc_sampling self.random_across_doc_sampling = random_across_doc_sampling
self.bias_for_single_doc = bias_for_single_doc
self.sentence_start = sentence_start self.sentence_start = sentence_start
self.init_weighting() self.init_weighting()
...@@ -510,7 +512,10 @@ class GPT2Dataset(data.Dataset): ...@@ -510,7 +512,10 @@ class GPT2Dataset(data.Dataset):
# truncate or pad tokens # truncate or pad tokens
num_tokens = len(tokens) num_tokens = len(tokens)
tokens_to_strip = num_tokens - self.max_seq_len - 1 if self.bias_for_single_doc:
tokens_to_strip = num_tokens - self.max_seq_len - 1
else:
tokens_to_strip = num_tokens - 1
if tokens_to_strip > 0: if tokens_to_strip > 0:
strip_left_tokens = rng.randint(tokens_to_strip + 1) strip_left_tokens = rng.randint(tokens_to_strip + 1)
tokens = tokens[strip_left_tokens:] tokens = tokens[strip_left_tokens:]
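For reference, a standalone restatement of the truncation logic above: when bias_for_single_doc is off, any prefix position may be stripped, and the dataset later pads or samples across documents to fill the window. The helper name is hypothetical; the NumPy RNG mirrors how the dataset builds its per-item rng.

import numpy as np

def random_window(tokens, max_seq_len, rng, bias_for_single_doc=False):
    # Sketch of the random left-truncation used by GPT2Dataset above.
    num_tokens = len(tokens)
    if bias_for_single_doc:
        tokens_to_strip = num_tokens - max_seq_len - 1
    else:
        tokens_to_strip = num_tokens - 1
    if tokens_to_strip > 0:
        strip_left = rng.randint(tokens_to_strip + 1)  # uniform over [0, tokens_to_strip]
        tokens = tokens[strip_left:]
    # the real dataset pads or samples across documents if this is too short
    return tokens[:max_seq_len + 1]

print(random_window(list(range(20)), 8, np.random.RandomState(0)))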
...@@ -576,7 +581,7 @@ class bert_sentencepair_dataset(data.Dataset): ...@@ -576,7 +581,7 @@ class bert_sentencepair_dataset(data.Dataset):
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1) dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
""" """
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True,**kwargs): def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
self.ds = ds self.ds = ds
self.ds_len = len(self.ds) self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer() self.tokenizer = self.ds.GetTokenizer()
...@@ -758,7 +763,8 @@ class bert_sentencepair_dataset(data.Dataset): ...@@ -758,7 +763,8 @@ class bert_sentencepair_dataset(data.Dataset):
""" """
tokens_a, token_types_a = a tokens_a, token_types_a = a
tokens_b, token_types_b = b tokens_b, token_types_b = b
max_num_tokens = max_seq_len - 3 max_num_tokens = self.calc_seq_len(max_seq_len)
# max_num_tokens = max_seq_len - 3
while True: while True:
len_a = len(tokens_a) len_a = len(tokens_a)
len_b = len(tokens_b) len_b = len(tokens_b)
...@@ -782,6 +788,9 @@ class bert_sentencepair_dataset(data.Dataset): ...@@ -782,6 +788,9 @@ class bert_sentencepair_dataset(data.Dataset):
trunc_types.pop() trunc_types.pop()
return (tokens_a, token_types_a), (tokens_b, token_types_b) return (tokens_a, token_types_a), (tokens_b, token_types_b)
def calc_seq_len(self, max_seq_len):
return max_seq_len - 3
def mask_token(self, idx, tokens, types, vocab_words, rng): def mask_token(self, idx, tokens, types, vocab_words, rng):
""" """
helper function to mask `idx` token from `tokens` according to helper function to mask `idx` token from `tokens` according to
...@@ -807,6 +816,11 @@ class bert_sentencepair_dataset(data.Dataset): ...@@ -807,6 +816,11 @@ class bert_sentencepair_dataset(data.Dataset):
seq += [self.tokenizer.get_command('pad').Id] * num_pad seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask return seq, pad_mask
def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
return tokens, token_types
def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng): def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
""" """
Mask sequence pair for BERT training according to: Mask sequence pair for BERT training according to:
...@@ -814,8 +828,7 @@ class bert_sentencepair_dataset(data.Dataset): ...@@ -814,8 +828,7 @@ class bert_sentencepair_dataset(data.Dataset):
""" """
tokens_a, token_types_a = a tokens_a, token_types_a = a
tokens_b, token_types_b = b tokens_b, token_types_b = b
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id] tokens, token_types = self.concat_tokens(tokens_a, token_types_a, tokens_b, token_types_b)
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
len_a = len(tokens_a) len_a = len(tokens_a)
len_b = len(tokens_b) len_b = len(tokens_b)
......
...@@ -111,7 +111,7 @@ class lazy_array_loader(object): ...@@ -111,7 +111,7 @@ class lazy_array_loader(object):
lazypath = get_lazy_path(path) lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type) datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string #get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb') self._file = open(datapath, 'rb', buffering=0)
self.file = self._file self.file = self._file
#memory map file if necessary #memory map file if necessary
self.mem_map = mem_map self.mem_map = mem_map
......
...@@ -795,7 +795,6 @@ class BertWordPieceTokenizer(Tokenizer): ...@@ -795,7 +795,6 @@ class BertWordPieceTokenizer(Tokenizer):
Tokens = Tokens.tokenization Tokens = Tokens.tokenization
return ' '.join(Tokens) return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer): class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, **kwargs): def __init__(self, cache_dir=None, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2', self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
...@@ -887,4 +886,3 @@ class GPT2BPETokenizer(Tokenizer): ...@@ -887,4 +886,3 @@ class GPT2BPETokenizer(Tokenizer):
if isinstance(Tokens, Tokenization): if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization Tokens = Tokens.tokenization
return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens]) return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
...@@ -210,10 +210,13 @@ def forward_step(data_iterator, model, args, timers): ...@@ -210,10 +210,13 @@ def forward_step(data_iterator, model, args, timers):
lm_loss = torch.sum( lm_loss = torch.sum(
losses.view(-1) * loss_mask.float()) losses.view(-1) * loss_mask.float())
else: else:
outputs = torch.argmax(output, -1).contiguous().view(-1) outputs = torch.argmax(output, -1)
acc = (outputs == lm_labels.contiguous().view(-1)).float() correct = (outputs == lm_labels).float()
loss_mask = loss_mask.contiguous().view(-1).float() correct[(1-loss_mask).bool()] = 1
lm_loss = torch.sum(acc * loss_mask) correct = correct.prod(-1)
lm_loss = correct.sum()
# loss_mask = loss_mask.contiguous().view(-1).float()
# lm_loss = torch.sum(acc * loss_mask)
return lm_loss return lm_loss
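The rewritten branch above switches cloze evaluation from counting correct tokens to counting correct samples: a sample scores 1 only if every non-masked target token is predicted correctly, which is the metric LAMBADA needs. A tiny self-contained illustration with toy tensors (names mirror the diff):

import torch

outputs = torch.tensor([[5, 2, 9], [1, 7, 3]])            # argmax predictions
lm_labels = torch.tensor([[5, 2, 0], [1, 7, 3]])          # targets
loss_mask = torch.tensor([[1, 1, 0], [1, 1, 1]]).float()  # 0 = position is ignored

correct = (outputs == lm_labels).float()
correct[(1 - loss_mask).bool()] = 1   # masked positions can never break a match
per_sample = correct.prod(-1)         # 1.0 only if all unmasked tokens match
print(per_sample.sum())               # tensor(2.)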
...@@ -345,7 +348,7 @@ def set_random_seed(seed): ...@@ -345,7 +348,7 @@ def set_random_seed(seed):
class LM_Eval_Dataset(torch.utils.data.Dataset): class LM_Eval_Dataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None): def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None, **kwargs):
self.tokens = tokens self.tokens = tokens
self.seq_len = seq_len self.seq_len = seq_len
self.pad_idx = pad_idx self.pad_idx = pad_idx
...@@ -379,15 +382,30 @@ class LM_Eval_Dataset(torch.utils.data.Dataset): ...@@ -379,15 +382,30 @@ class LM_Eval_Dataset(torch.utils.data.Dataset):
return {'text': np.array(tokens), 'pad_mask': pad_mask} return {'text': np.array(tokens), 'pad_mask': pad_mask}
class Lambada_Eval_Dataset(torch.utils.data.Dataset): class Lambada_Eval_Dataset(torch.utils.data.Dataset):
def __init__(self, path, tokenizer, seq_len): def __init__(self, path, tokenizer, seq_len, strict=False, **kwargs):
self.seq_len = seq_len self.seq_len = seq_len
self.pad_idx = tokenizer.get_command('pad').Id self.pad_idx = tokenizer.get_command('pad').Id
self.tokenizer = tokenizer
self.strict = strict
self.tokens = [] self.tokens = []
self.labels = []
with open(path, 'r') as f: with open(path, 'r') as f:
for line in f.readlines(): for line in f.readlines():
text = json.loads(line)['text'] text = json.loads(line)['text']
self.tokens.append(tokenizer.EncodeAsIds(text).tokenization) tokens, labels = self.get_tokens(text)
self.tokens.append(tokens)
self.labels.append(labels)
def get_tokens(self, text):
if not self.strict:
tokens = self.tokenizer.EncodeAsIds(text).tokenization
return tokens[:-1], [tokens[-1]]
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = self.tokenizer.EncodeAsIds(text[:start_idx].strip()).tokenization
last_token = self.tokenizer.EncodeAsIds(' '+last_token).tokenization
return beginning_tokens, last_token
def __len__(self): def __len__(self):
return len(self.tokens) return len(self.tokens)
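The get_tokens method above implements the two LAMBADA formulations: non-strict keeps the whole passage and scores only the final BPE token, while strict drops the last whitespace-delimited word from the context and scores every token of that word. A toy illustration of the difference, using a character-level stand-in for the BPE tokenizer (the stand-in is an assumption made for the example only):

def encode(s):
    return list(s)  # character-level stand-in for tokenizer.EncodeAsIds

text = "the last word"

# non-strict: everything but the final token is context, the final token is the label
tokens = encode(text)
ns_context, ns_label = tokens[:-1], [tokens[-1]]

# strict: the context stops before the final word and the label is the whole word
last_word = text.split()[-1]
start_idx = text.rfind(last_word)
s_context = encode(text[:start_idx].strip())
s_label = encode(' ' + last_word)

print(ns_label)  # ['d']
print(s_label)   # [' ', 'w', 'o', 'r', 'd']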
...@@ -397,7 +415,10 @@ class Lambada_Eval_Dataset(torch.utils.data.Dataset): ...@@ -397,7 +415,10 @@ class Lambada_Eval_Dataset(torch.utils.data.Dataset):
tokens = self.tokens[idx] tokens = self.tokens[idx]
num_tokens = len(tokens) num_tokens = len(tokens)
pad_mask = [0]*num_tokens pad_mask = [0]*num_tokens
pad_mask[-1] = 1 labels = self.labels[idx]
pad_mask += [1]*len(labels)
tokens = tokens+labels
num_tokens = len(tokens)
if num_tokens < self.seq_len+1: if num_tokens < self.seq_len+1:
num_pad = (self.seq_len+1-num_tokens) num_pad = (self.seq_len+1-num_tokens)
pad_mask += [0]*(num_pad) pad_mask += [0]*(num_pad)
...@@ -442,7 +463,7 @@ def get_eval_data(args): ...@@ -442,7 +463,7 @@ def get_eval_data(args):
val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token, val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
args.overlapping_eval) args.overlapping_eval)
else: else:
val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len) val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len, args.strict_lambada)
num_tokenized_tokens = 0 num_tokenized_tokens = 0
num_original_tokens = 0 num_original_tokens = 0
val_dataloader = torch.utils.data.DataLoader( val_dataloader = torch.utils.data.DataLoader(
...@@ -450,7 +471,9 @@ def get_eval_data(args): ...@@ -450,7 +471,9 @@ def get_eval_data(args):
before = tokenizer.num_tokens before = tokenizer.num_tokens
after = before after = before
while after % mpu.get_model_parallel_world_size() != 0: multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1 after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'. print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'.
format(before, after - before, after)) format(before, after - before, after))
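This hunk, like the prepare_tokenizer change later in the commit, pads the vocabulary up to a multiple of make_vocab_size_divisible_by times the model-parallel world size, so the embedding table splits into equal shards. A minimal restatement of the rule (the function name is hypothetical):

def pad_vocab_size(orig_vocab_size, make_vocab_size_divisible_by, model_parallel_size):
    # round the vocab size up to the nearest multiple of the padding unit
    multiple = make_vocab_size_divisible_by * model_parallel_size
    after = orig_vocab_size
    while after % multiple != 0:
        after += 1
    return after

# e.g. the 50257-token GPT-2 vocab padded to a multiple of 128 * 8
print(pad_vocab_size(50257, 128, 8))  # 51200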
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
import os import os
import random import random
import json
import copy
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -83,9 +85,10 @@ def setup_model(args): ...@@ -83,9 +85,10 @@ def setup_model(args):
return model return model
def get_batch(context_tokens, device, args): def get_batch(context_tokens, args):
tokens = context_tokens tokens = context_tokens
tokens = tokens.view(args.batch_size, -1).contiguous() tokens = tokens.view(args.batch_size, -1).contiguous()
device = args.device
tokens = tokens.to(device) tokens = tokens.to(device)
# Get the masks and postition ids. # Get the masks and postition ids.
...@@ -108,8 +111,8 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): ...@@ -108,8 +111,8 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
if top_p > 0.0: if top_p > 0.0:
#convert to 1D #convert to 1D
logits=logits.view(logits.size()[1]).contiguous() # logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold # Remove tokens with cumulative probability above the threshold
...@@ -117,16 +120,33 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): ...@@ -117,16 +120,33 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# Shift the indices to the right to keep also the first token above the threshold # Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0 sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove] for i in range(sorted_indices.size(0)):
logits[indices_to_remove] = filter_value indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
#going back to 2D #going back to 2D
logits=logits.view(1, -1).contiguous() # logits=logits.view(1, -1).contiguous()
return logits return logits
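The per-row loop above makes the nucleus (top-p) filter work on a whole batch of logits instead of a single row. A self-contained sketch of the combined top-k / top-p filter for [batch, vocab] logits; this is a generic restatement of the technique, not the repo's exact function:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    if top_k > 0:
        # keep only the k largest logits in each row
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # shift right so the first token that crosses the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # map the mask back from sorted order to the original token order
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits

probs = F.softmax(filter_logits(torch.randn(2, 10), top_k=5, top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)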
def generate_samples_input_from_file(model, tokenizer, args):
def generate_samples(model, tokenizer, args, device): if args.sample_input_file == "":
if mpu.get_model_parallel_rank() == 0:
print("args.sample_input_file CAN NOT BE empty!\n")
return
if mpu.get_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file == "":
print("Argument: sample-output-file can't be empty, setting it to\n")
print("\t args.sample_input_file.out")
args.sample_output_file = args.sample_input_file+".out"
fname_out = open(args.sample_output_file, "w+")
context_count=0 context_count=0
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
...@@ -135,6 +155,74 @@ def generate_samples(model, tokenizer, args, device): ...@@ -135,6 +155,74 @@ def generate_samples(model, tokenizer, args, device):
terminate_runs=0 terminate_runs=0
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
raw_text = "stop"
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
context_length = len(context_tokens)
if context_length >=args.seq_length//2:
print("\nContext length", context_length, \
"\nPlease give smaller context (half of the sequence length)!")
continue
else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:")
fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out.write("\n")
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
def generate_samples_interactive(model, tokenizer, args):
print_frequency = 24
context_count=0
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs=0
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ") raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text: while not raw_text:
print('Prompt should not be empty!') print('Prompt should not be empty!')
...@@ -161,60 +249,179 @@ def generate_samples(model, tokenizer, args, device): ...@@ -161,60 +249,179 @@ def generate_samples(model, tokenizer, args, device):
if terminate_runs == 1: if terminate_runs == 1:
return return
pad_id = tokenizer.get_command('pad').Id
if context_length < args.seq_length:
context_tokens.extend([pad_id] * (args.seq_length - context_length))
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
context_length = context_length_tensor[0].item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, device, args)
start_time = time.time() start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
counter = 0
org_context_length = context_length
while counter < (org_context_length + args.out_seq_length):
logits = model(tokens, position_ids, attention_mask)
logits = logits[:, context_length - 1, :] / args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
tokens[0, context_length] = prev[0]
context_length += 1
counter += 1
output_tokens_list = tokens.view(-1).contiguous()
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
token_end = decode_tokens.find("<|endoftext|>")
if mpu.get_model_parallel_rank() == 0 and (counter % 16 == 0 or token_end != -1):
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")]
print("\nGPT2:", trim_decode_tokens, flush=True)
if token_end != -1:
break
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) #print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
output_tokens_list = tokens.view(-1).contiguous() trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) #print("\nGPT2:", trim_decode_tokens, flush=True)
trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")] #print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
print("\nGPT2:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1 context_count += 1
if mpu.get_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
def generate_samples_unconditional(model, tokenizer, args):
num_samples = args.num_samples
context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)]
samples = []
# with open(args.genfile, 'w') as f:
ctr = 0
while True:
start_time = time.time()
for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args):
pass
# token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
if ctr%args.log_interval == 0:
print('Avg s/batch:', (time.time()-start_time)/min(args.log_interval, ctr+1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length-1]
text = tokenizer.DecodeIds(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length-1, 'finished': is_finished}
yield datum
ctr += 1
if ctr >= num_samples:
break
if ctr >= num_samples:
break
def write_and_generate_samples_unconditional(model, tokenizer, args):
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model, tokenizer, args):
f.write(json.dumps(datum)+'\n')
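Each unconditional sample is written as one JSON object per line, so the --genfile output can be consumed with a plain line-by-line reader (the path below is hypothetical):

import json

with open('unconditional_samples.jsonl') as f:  # hypothetical --genfile path
    for line in f:
        datum = json.loads(line)
        # keys written above: 'text', 'length', 'finished'
        print(datum['length'], datum['finished'], datum['text'][:60])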
def pad_batch(batch, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
tokens.extend([pad_id]*(args.seq_length-context_length))
context_lengths.append(context_length)
return batch, context_lengths
def get_token_stream(model, context_tokens, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
# context_length = len(context_tokens)
# if context_length < args.seq_length:
# context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
# context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, args)
counter = 0
org_context_length = context_length
layer_past = None
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
for tokens, lengths in batch_token_iterator:
context_length += 1
yield tokens[:, :context_length], lengths
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1-boolean)*val1 + boolean*val2
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
eos_id = tokenizer.get_command('eos').Id
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda()*maxlen
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens, position_ids, attention_mask)
logits = logits[:, context_length - 1, :]
else:
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_present=True)
logits = logits[:, -1].view(batch_size,-1).contiguous()
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
print_logits = []
for p in prev:
print_logits.append([logits[i, p].item() for i in range(batch_size)])
started = context_lengths <= context_length
tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started)
context_length += 1
counter += 1
done_token = (prev == eos_id).byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
was_done = is_done
is_done = is_done | done_token
done = torch.all(is_done)
yield tokens, lengths
if done:
break
def prepare_tokenizer(args): def prepare_tokenizer(args):
...@@ -232,8 +439,11 @@ def prepare_tokenizer(args): ...@@ -232,8 +439,11 @@ def prepare_tokenizer(args):
args.eod_token = tokenizer.get_command('eos').Id args.eod_token = tokenizer.get_command('eos').Id
after = tokenizer.num_tokens after = tokenizer.num_tokens
while after % mpu.get_model_parallel_world_size() != 0: multiple = args.make_vocab_size_divisible_by * \
after += 1 mpu.get_model_parallel_world_size()
if multiple != 0:
while (after % multiple) != 0:
after += 1
args.vocab_size = after args.vocab_size = after
print("prepare tokenizer done", flush=True) print("prepare tokenizer done", flush=True)
...@@ -267,10 +477,19 @@ def main(): ...@@ -267,10 +477,19 @@ def main():
model = setup_model(args) model = setup_model(args)
#setting default batch size to 1 #setting default batch size to 1
args.batch_size = 1 # args.batch_size = 1
args.device = torch.cuda.current_device()
#generate samples #generate samples
generate_samples(model, tokenizer, args, torch.cuda.current_device()) if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model, tokenizer, args)
else:
generate_samples_interactive(model, tokenizer, args)
else:
write_and_generate_samples_unconditional(model, tokenizer, args)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -18,36 +18,48 @@ import torch ...@@ -18,36 +18,48 @@ import torch
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
import math import math
from utils import print_rank_0
class AnnealingLR(_LRScheduler): class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve.""" """Anneals the learning rate"""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] DECAY_STYLES = ['linear', 'cosine', 'constant', 'None']
def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1): def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
decay_style=None, last_iter=-1, min_lr=0.0,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
self.optimizer = optimizer self.optimizer = optimizer
self.start_lr = start_lr self.start_lr = start_lr
self.min_lr = min_lr
self.warmup_iter = warmup_iter self.warmup_iter = warmup_iter
self.num_iters = last_iter + 1 self.num_iters = last_iter + 1
self.end_iter = num_iters self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
else None
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
self.step(self.num_iters) self.step(self.num_iters)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('learning rate decaying', decay_style) print('learning rate decaying', decay_style)
def get_lr(self): def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4 # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * self.num_iters / self.warmup_iter return float(self.start_lr) * num_iters_ / self.warmup_iter
else: else:
if self.decay_style == self.DECAY_STYLES[0]: if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter) lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
elif self.decay_style == self.DECAY_STYLES[1]: elif self.decay_style == self.DECAY_STYLES[1]:
return self.start_lr / 2.0 * (math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1) lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
#TODO: implement exponential decay
return self.start_lr
else: else:
return self.start_lr lr = self.start_lr
return max(lr, self.min_lr)
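Putting the pieces of get_lr together: linear warmup up to warmup_iter, then linear or cosine decay over end_iter iterations, with the result clamped at min_lr. A standalone restatement (the function name is hypothetical; the formulas mirror the method above):

import math

def annealed_lr(start_lr, min_lr, warmup_iter, end_iter, num_iters, decay_style='cosine'):
    num_iters_ = min(num_iters, end_iter - warmup_iter)
    if warmup_iter > 0 and num_iters <= warmup_iter:
        return float(start_lr) * num_iters_ / warmup_iter
    if decay_style == 'linear':
        lr = start_lr * ((end_iter - (num_iters_ - warmup_iter)) / end_iter)
    elif decay_style == 'cosine':
        lr = start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - warmup_iter) / end_iter) + 1)
    else:  # constant / None
        lr = start_lr
    return max(lr, min_lr)

# e.g. 1e-4 peak, 1e-5 floor, 2000 warmup iterations, 300k decay iterations
for it in (0, 1000, 2000, 150000, 300000):
    print(it, annealed_lr(1.0e-4, 1.0e-5, 2000, 300000, it))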
def step(self, step_num=None): def step(self, step_num=None):
if step_num is None: if step_num is None:
...@@ -63,14 +75,38 @@ class AnnealingLR(_LRScheduler): ...@@ -63,14 +75,38 @@ class AnnealingLR(_LRScheduler):
'warmup_iter': self.warmup_iter, 'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters, 'num_iters': self.num_iters,
'decay_style': self.decay_style, 'decay_style': self.decay_style,
'end_iter': self.end_iter 'end_iter': self.end_iter,
'min_lr': self.min_lr
} }
return sd return sd
def check_and_set_(self, cls_value, sd_value, name):
if self.override_lr_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value
else:
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, 'AnnealingLR: class input value ' \
'and checkpoint values for {} do not match'.format(name)
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name))
return sd_value
def load_state_dict(self, sd): def load_state_dict(self, sd):
self.start_lr = sd['start_lr']
self.warmup_iter = sd['warmup_iter'] self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
'learning rate')
self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.warmup_iter = self.check_and_set_(self.warmup_iter,
sd['warmup_iter'],
'warmup iterations')
self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
'total number of iterations')
self.decay_style = self.check_and_set_(self.decay_style,
sd['decay_style'],
'decay style')
self.num_iters = sd['num_iters'] self.num_iters = sd['num_iters']
self.end_iter = sd['end_iter']
self.decay_style = sd['decay_style']
self.step(self.num_iters) self.step(self.num_iters)
...@@ -65,6 +65,14 @@ class GPT2Model(torch.nn.Module): ...@@ -65,6 +65,14 @@ class GPT2Model(torch.nn.Module):
# Position embedding (serial). # Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size) hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self.tokentype_embeddings = None
self.hidden_size = hidden_size
# Initialize the position embeddings. # Initialize the position embeddings.
init_method(self.position_embeddings.weight) init_method(self.position_embeddings.weight)
...@@ -80,18 +88,39 @@ class GPT2Model(torch.nn.Module): ...@@ -80,18 +88,39 @@ class GPT2Model(torch.nn.Module):
checkpoint_activations, checkpoint_activations,
checkpoint_num_layers) checkpoint_num_layers)
def forward(self, input_ids, position_ids, attention_mask):
def add_tokentype_embeddings(self, num_tokentypes):
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
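The token type table is created lazily so a GPT-2 checkpoint trained without token types still loads, and the embeddings can be bolted on afterwards. A minimal sketch of the same pattern on a toy module (names follow the diff, sizes are arbitrary):

import torch

class ToyEmbeddingModel(torch.nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.hidden_size = hidden_size
        self.tokentype_embeddings = None  # added later so old checkpoints still load

    def add_tokentype_embeddings(self, num_tokentypes):
        assert self.tokentype_embeddings is None
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)

    def forward(self, input_ids, tokentype_ids=None):
        emb = self.word_embeddings(input_ids)
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            emb = emb + self.tokentype_embeddings(tokentype_ids)
        return emb

model = ToyEmbeddingModel()
ids = torch.randint(0, 100, (2, 8))
out = model(ids)                           # works without token types
model.add_tokentype_embeddings(2)          # e.g. after loading a pretrained checkpoint
out = model(ids, tokentype_ids=torch.zeros_like(ids))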
def forward(self, input_ids, position_ids, attention_mask,
layer_past=None, get_present=False, tokentype_ids=None):
# Embeddings. # Embeddings.
words_embeddings = self.word_embeddings(input_ids) words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout. # Dropout.
embeddings = self.embedding_dropout(embeddings) embeddings = self.embedding_dropout(embeddings)
# Transformer. # Transformer.
transformer_output = self.transformer(embeddings, attention_mask) transformer_output = self.transformer(embeddings, attention_mask,
layer_past=layer_past,
get_present=get_present)
if get_present:
transformer_output, presents = transformer_output
# Parallel logits. # Parallel logits.
transformer_output_parallel = mpu.copy_to_model_parallel_region( transformer_output_parallel = mpu.copy_to_model_parallel_region(
...@@ -100,9 +129,12 @@ class GPT2Model(torch.nn.Module): ...@@ -100,9 +129,12 @@ class GPT2Model(torch.nn.Module):
self.word_embeddings.weight) self.word_embeddings.weight)
if self.parallel_output: if self.parallel_output:
return logits_parallel output = logits_parallel
else:
return mpu.gather_from_model_parallel_region(logits_parallel) output = mpu.gather_from_model_parallel_region(logits_parallel)
if get_present:
output = [output, presents]
return output
def gpt2_get_params_for_weight_decay_optimization(module): def gpt2_get_params_for_weight_decay_optimization(module):
......
...@@ -98,7 +98,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module): ...@@ -98,7 +98,7 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
tensor = tensor.view(*new_tensor_shape) tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3) return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, ltor_mask): def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s] # ltor_mask: [1, 1, s, s]
...@@ -112,13 +112,24 @@ class GPT2ParallelSelfAttention(torch.nn.Module): ...@@ -112,13 +112,24 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
query_layer = self._transpose_for_scores(mixed_query_layer) query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer) key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer) value_layer = self._transpose_for_scores(mixed_value_layer)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s] # Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer, norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer/norm_factor,
attention_scores = attention_scores / math.sqrt( key_layer.transpose(-1, -2)/norm_factor)
self.hidden_size_per_attention_head)
# Apply the left to right attention mask. # Apply the left to right attention mask.
if get_present:
with torch.no_grad():
if layer_past is not None:
ltor_mask = ltor_mask[...,attention_scores.size(3)-1, :attention_scores.size(3)].unsqueeze(2)
else:
ltor_mask = ltor_mask[...,:attention_scores.size(3), :attention_scores.size(3)]
attention_scores = torch.mul(attention_scores, ltor_mask) - \ attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask) 10000.0 * (1.0 - ltor_mask)
...@@ -143,6 +154,9 @@ class GPT2ParallelSelfAttention(torch.nn.Module): ...@@ -143,6 +154,9 @@ class GPT2ParallelSelfAttention(torch.nn.Module):
output = self.dense(context_layer) output = self.dense(context_layer)
output = self.output_dropout(output) output = self.output_dropout(output)
if get_present:
output = [output, present]
return output return output
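Two things happen in this hunk: past keys/values are cached per layer (layer_past / get_present) so each new token only attends from its own position, and the score scaling is rewritten so that both query and key are divided by sqrt(sqrt(d)), which equals the usual 1/sqrt(d) scaling of the product but keeps fp16 intermediates smaller. A minimal single-head sketch of the caching idea; shapes and names are assumptions rather than this repo's API, and the causal mask is omitted for brevity:

import math
import torch

def cached_attention(query, key, value, layer_past=None):
    # query/key/value: [batch, new_len, dim]; returns (context, present)
    if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=1)        # reuse cached keys
        value = torch.cat((past_value, value), dim=1)  # reuse cached values
    present = (key, value)
    norm = math.sqrt(math.sqrt(query.size(-1)))
    # (q / n) @ (k / n)^T == q @ k^T / sqrt(d), with smaller intermediate values
    scores = torch.matmul(query / norm, (key / norm).transpose(-1, -2))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value), present

q = torch.randn(1, 4, 8); k = torch.randn(1, 4, 8); v = torch.randn(1, 4, 8)
ctx, past = cached_attention(q, k, v)                        # process the prompt once
q1 = torch.randn(1, 1, 8); k1 = torch.randn(1, 1, 8); v1 = torch.randn(1, 1, 8)
ctx1, past = cached_attention(q1, k1, v1, layer_past=past)   # one new token reuses the cache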
...@@ -268,14 +282,16 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): ...@@ -268,14 +282,16 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
init_method, init_method,
output_layer_init_method=output_layer_init_method) output_layer_init_method=output_layer_init_method)
def forward(self, hidden_states, ltor_mask): def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s] # ltor_mask: [1, 1, s, s]
# Layer norm at the begining of the transformer layer. # Layer norm at the begining of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
# Self attention. # Self attention.
attention_output = self.attention(layernorm_output, ltor_mask) attention_output = self.attention(layernorm_output, ltor_mask, layer_past=layer_past, get_present=get_present)
if get_present:
attention_output, presents = attention_output
# Residual connection. # Residual connection.
layernorm_input = hidden_states + attention_output layernorm_input = hidden_states + attention_output
# Layer norm post the self attention. # Layer norm post the self attention.
...@@ -285,6 +301,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module): ...@@ -285,6 +301,9 @@ class GPT2ParallelTransformerLayer(torch.nn.Module):
# Second residual connection. # Second residual connection.
output = layernorm_input + mlp_output output = layernorm_input + mlp_output
if get_present:
output = [output, presents]
return output return output
...@@ -376,7 +395,7 @@ class GPT2ParallelTransformer(torch.nn.Module): ...@@ -376,7 +395,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, attention_mask): def forward(self, hidden_states, attention_mask, layer_past=None, get_present=False):
def custom(start, end): def custom(start, end):
def custom_forward(*inputs): def custom_forward(*inputs):
...@@ -387,7 +406,7 @@ class GPT2ParallelTransformer(torch.nn.Module): ...@@ -387,7 +406,7 @@ class GPT2ParallelTransformer(torch.nn.Module):
return x_ return x_
return custom_forward return custom_forward
if self.checkpoint_activations: if self.checkpoint_activations and not get_present:
l = 0 l = 0
num_layers = len(self.layers) num_layers = len(self.layers)
chunk_length = self.checkpoint_num_layers chunk_length = self.checkpoint_num_layers
...@@ -396,11 +415,20 @@ class GPT2ParallelTransformer(torch.nn.Module): ...@@ -396,11 +415,20 @@ class GPT2ParallelTransformer(torch.nn.Module):
hidden_states, attention_mask) hidden_states, attention_mask)
l += chunk_length l += chunk_length
else: else:
for layer in self.layers: presents = []
hidden_states = layer(hidden_states, attention_mask) for i, layer in enumerate(self.layers):
past = None
if layer_past is not None:
past = layer_past[i]
hidden_states = layer(hidden_states, attention_mask, layer_past=past, get_present=get_present)
if get_present:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm. # Final layer norm.
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
if get_present:
output = [output, presents]
return output return output
......
...@@ -46,7 +46,8 @@ from utils import report_memory ...@@ -46,7 +46,8 @@ from utils import report_memory
from utils import print_args from utils import print_args
from utils import print_params_min_max_norm from utils import print_params_min_max_norm
from utils import print_rank_0 from utils import print_rank_0
from utils import enable_adlr_autoresume
from utils import check_adlr_autoresume_termination
def get_model(args): def get_model(args):
"""Build the model.""" """Build the model."""
...@@ -114,7 +115,8 @@ def get_optimizer(model, args): ...@@ -114,7 +115,8 @@ def get_optimizer(model, args):
param.model_parallel = False param.model_parallel = False
# Use Adam. # Use Adam.
optimizer = Adam(param_groups, betas = (0.9, 0.999)
optimizer = Adam(param_groups, betas=betas,
lr=args.lr, weight_decay=args.weight_decay) lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer. # Wrap into fp16 optimizer.
...@@ -145,7 +147,10 @@ def get_learning_rate_scheduler(optimizer, args): ...@@ -145,7 +147,10 @@ def get_learning_rate_scheduler(optimizer, args):
warmup_iter=warmup_iter, warmup_iter=warmup_iter,
num_iters=num_iters, num_iters=num_iters,
decay_style=args.lr_decay_style, decay_style=args.lr_decay_style,
last_iter=init_step) last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler return lr_scheduler
...@@ -299,7 +304,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, ...@@ -299,7 +304,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
def train(model, optimizer, lr_scheduler, def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args): train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model.""" """Train the model."""
# Turn on training mode which enables dropout. # Turn on training mode which enables dropout.
...@@ -326,15 +331,37 @@ def train(model, optimizer, lr_scheduler, ...@@ -326,15 +331,37 @@ def train(model, optimizer, lr_scheduler,
iteration += 1 iteration += 1
# Update losses. # Update losses.
total_lm_loss += lm_loss.data.detach().float() current_lm_loss = lm_loss.data.detach().float()
total_nsp_loss += nsp_loss.data.detach().float() current_nsp_loss = nsp_loss.data.detach().float()
total_lm_loss += current_lm_loss
total_nsp_loss += current_nsp_loss
# Logging. # Logging.
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('lm_loss', current_lm_loss, iteration)
writer.add_scalar('nsp_loss', current_nsp_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0: if iteration % args.log_interval == 0:
learning_rate = optimizer.param_groups[0]['lr']
avg_nsp_loss = total_nsp_loss.item() / args.log_interval avg_nsp_loss = total_nsp_loss.item() / args.log_interval
avg_lm_loss = total_lm_loss.item() / args.log_interval avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed() elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration, log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters) args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
...@@ -351,9 +378,13 @@ def train(model, optimizer, lr_scheduler, ...@@ -351,9 +378,13 @@ def train(model, optimizer, lr_scheduler,
if report_memory_flag: if report_memory_flag:
report_memory('after {} iterations'.format(iteration)) report_memory('after {} iterations'.format(iteration))
report_memory_flag = False report_memory_flag = False
timers.log(['forward', 'backward', 'optimizer', 'batch generator', timers.log(timers_to_log, normalizer=args.log_interval)
'data loader'],
normalizer=args.log_interval) # Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
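The autoresume hook polls the cluster scheduler every `--adlr-autoresume-interval` iterations and, if the job is about to be preempted, checkpoints and exits so it can be requeued. The real helpers live in utils.py and talk to an internal NVIDIA service; the sketch below only illustrates the pattern, with `scheduler_requests_stop` and `save_checkpoint` passed in as assumed callables.

    import sys
    import torch

    def maybe_stop_for_requeue(iteration, model, optimizer, lr_scheduler, args,
                               scheduler_requests_stop, save_checkpoint):
        """Illustrative autoresume check; both callables are assumptions."""
        if not scheduler_requests_stop():
            return
        # Make sure all ranks agree before writing the checkpoint and exiting.
        torch.distributed.barrier()
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            print('exiting for autoresume at iteration {}'.format(iteration))
        sys.exit(0)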
# Checkpointing # Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0: if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args) save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
...@@ -361,8 +392,8 @@ def train(model, optimizer, lr_scheduler, ...@@ -361,8 +392,8 @@ def train(model, optimizer, lr_scheduler,
# Evaluation # Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid: if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration) prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results( evaluate_and_print_results(prefix, val_data_iterator, model, args,
prefix, val_data_iterator, model, args, timers, False) writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier() torch.distributed.barrier()
...@@ -413,7 +444,8 @@ def evaluate(data_iterator, model, args, timers, verbose = False): ...@@ -413,7 +444,8 @@ def evaluate(data_iterator, model, args, timers, verbose = False):
def evaluate_and_print_results(prefix, data_iterator, model, def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False): args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen.""" """Helper function to evaluate and dump results on screen."""
lm_loss, nsp_loss = evaluate(data_iterator, model, lm_loss, nsp_loss = evaluate(data_iterator, model,
args, timers, verbose) args, timers, verbose)
...@@ -428,6 +460,11 @@ def evaluate_and_print_results(prefix, data_iterator, model, ...@@ -428,6 +460,11 @@ def evaluate_and_print_results(prefix, data_iterator, model,
print_rank_0(string) print_rank_0(string)
print_rank_0('-' * length) print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_lm_loss', lm_loss, iteration)
writer.add_scalar('val_nsp_loss', nsp_loss, iteration)
writer.add_scalar('val_total_loss', val_loss, iteration)
return val_loss return val_loss
...@@ -471,7 +508,8 @@ def get_train_val_test_data(args): ...@@ -471,7 +508,8 @@ def get_train_val_test_data(args):
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
data_config = configure_data() data_config = configure_data()
data_config.set_defaults(data_set_type='BERT', transpose=False) ds_type = 'BERT'
data_config.set_defaults(data_set_type=ds_type, transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(args) (train_data, val_data, test_data), tokenizer = data_config.apply(args)
before = tokenizer.num_tokens before = tokenizer.num_tokens
after = before after = before
...@@ -514,11 +552,27 @@ def main(): ...@@ -514,11 +552,27 @@ def main():
# Arguments. # Arguments.
args = get_args() args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed. # Pytorch distributed.
initialize_distributed(args) initialize_distributed(args)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('Pretrain BERT model') print('Pretrain BERT model')
print_args(args) print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility. # Random seeds for reproducibility.
set_random_seed(args.seed) set_random_seed(args.seed)
...@@ -534,11 +588,15 @@ def main(): ...@@ -534,11 +588,15 @@ def main():
if train_data is not None: if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \ train_data.batch_sampler.start_iter = args.iteration % \
len(train_data) len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None: if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \ start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_interval args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \ val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data) len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
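The validation sampler's resume offset is now derived from the checkpoint's iteration count instead of `--train-iters` and `--save-interval`: the number of completed evaluation rounds (`iteration // eval_interval`) times `--eval-iters` batches, taken modulo the dataset length. A worked example with made-up numbers:

    iteration = 25000       # loaded from the checkpoint (assumed)
    eval_interval = 1000    # evaluate every 1000 training iterations
    eval_iters = 100        # validation batches consumed per evaluation
    len_val_data = 7000

    start_iter_val = (iteration // eval_interval) * eval_iters   # 25 * 100 = 2500
    start_iter = start_iter_val % len_val_data                   # 2500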
if train_data is not None: if train_data is not None:
train_data_iterator = iter(train_data) train_data_iterator = iter(train_data)
...@@ -556,11 +614,12 @@ def main(): ...@@ -556,11 +614,12 @@ def main():
lr_scheduler, lr_scheduler,
train_data_iterator, train_data_iterator,
val_data_iterator, val_data_iterator,
timers, args) timers, args, writer)
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator, val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False) model, args, writer, iteration,
timers, False)
if args.save and iteration != 0: if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args) save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
...@@ -574,7 +633,7 @@ def main(): ...@@ -574,7 +633,7 @@ def main():
# Run on test data. # Run on test data.
prefix = 'the end of training for test data' prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator, evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True) model, args, None, 0, timers, True)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
"""Pretrain GPT2""" """Pretrain GPT2"""
# Flag to use Pytorch ddp which uses overlapping communication and computation.
USE_TORCH_DDP = False
from datetime import datetime from datetime import datetime
import os import os
import random import random
...@@ -33,10 +29,7 @@ from fp16 import FP16_Optimizer ...@@ -33,10 +29,7 @@ from fp16 import FP16_Optimizer
from learning_rates import AnnealingLR from learning_rates import AnnealingLR
from model import GPT2Model from model import GPT2Model
from model import gpt2_get_params_for_weight_decay_optimization from model import gpt2_get_params_for_weight_decay_optimization
if USE_TORCH_DDP: from model import DistributedDataParallel as LocalDDP
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
else:
from model import DistributedDataParallel as DDP
import mpu import mpu
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from utils import Timers from utils import Timers
...@@ -46,10 +39,11 @@ from utils import report_memory ...@@ -46,10 +39,11 @@ from utils import report_memory
from utils import print_args from utils import print_args
from utils import print_params_min_max_norm from utils import print_params_min_max_norm
from utils import print_rank_0 from utils import print_rank_0
from utils import enable_adlr_autoresume
from utils import check_adlr_autoresume_termination
from gpt2_data_loader import make_gpt2_dataloaders from gpt2_data_loader import make_gpt2_dataloaders
def get_model(args): def get_model(args):
"""Build the model.""" """Build the model."""
...@@ -79,12 +73,18 @@ def get_model(args): ...@@ -79,12 +73,18 @@ def get_model(args):
model = FP16_Module(model) model = FP16_Module(model)
# Wrap model for distributed training. # Wrap model for distributed training.
if USE_TORCH_DDP: if args.DDP_impl == 'torch':
i = torch.cuda.current_device() i = torch.cuda.current_device()
model = DDP(model, device_ids=[i], output_device=i, args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
process_group=mpu.get_data_parallel_group()) model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
else: else:
model = DDP(model) print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model return model
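The compile-time USE_TORCH_DDP switch is replaced by a runtime `--DDP-impl` choice between PyTorch's DistributedDataParallel and the repo's local implementation, and the selected class is stashed on `args.DDP_type` so later `isinstance` checks keep working. A condensed sketch of the same selection logic, assuming `--DDP-impl` is parsed elsewhere with choices `['local', 'torch']`:

    import torch

    def wrap_model_for_ddp(model, args, local_ddp_cls, data_parallel_group=None):
        """Condensed version of the DDP wrapping switch shown above."""
        if args.DDP_impl == 'torch':
            device = torch.cuda.current_device()
            args.DDP_type = torch.nn.parallel.DistributedDataParallel
            return args.DDP_type(model, device_ids=[device], output_device=device,
                                 process_group=data_parallel_group)
        if args.DDP_impl == 'local':
            args.DDP_type = local_ddp_cls
            return args.DDP_type(model)
        raise ValueError('Unknown DDP implementation: {}'.format(args.DDP_impl))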
...@@ -93,7 +93,7 @@ def get_optimizer(model, args): ...@@ -93,7 +93,7 @@ def get_optimizer(model, args):
"""Set up the optimizer.""" """Set up the optimizer."""
# Build parameter groups (weight decay and non-decay). # Build parameter groups (weight decay and non-decay).
while isinstance(model, (DDP, FP16_Module)): while isinstance(model, (args.DDP_type, FP16_Module)):
model = model.module model = model.module
param_groups = gpt2_get_params_for_weight_decay_optimization(model) param_groups = gpt2_get_params_for_weight_decay_optimization(model)
...@@ -136,7 +136,10 @@ def get_learning_rate_scheduler(optimizer, args): ...@@ -136,7 +136,10 @@ def get_learning_rate_scheduler(optimizer, args):
warmup_iter=warmup_iter, warmup_iter=warmup_iter,
num_iters=num_iters, num_iters=num_iters,
decay_style=args.lr_decay_style, decay_style=args.lr_decay_style,
last_iter=init_step) last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler return lr_scheduler
...@@ -159,7 +162,8 @@ def setup_model_and_optimizer(args): ...@@ -159,7 +162,8 @@ def setup_model_and_optimizer(args):
def get_masks_and_position_ids(data, def get_masks_and_position_ids(data,
eod_token, eod_token,
reset_position_ids, reset_position_ids,
reset_attention_mask): reset_attention_mask,
eod_mask_loss):
# Extract batch size and sequence length. # Extract batch size and sequence length.
batch_size, seq_length = data.size() batch_size, seq_length = data.size()
...@@ -175,7 +179,8 @@ def get_masks_and_position_ids(data, ...@@ -175,7 +179,8 @@ def get_masks_and_position_ids(data,
# Loss mask. # Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
loss_mask[data == eod_token] = 0.0 if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
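With `--eod-mask-loss` set, positions holding the end-of-document token are excluded from the language-model loss; without it the loss mask stays all-ones. A toy illustration (the eod id of 0 and the token values are made up):

    import torch

    data = torch.tensor([[5, 7, 0, 9, 0, 3]])   # 0 plays the role of the eod token
    eod_token = 0
    eod_mask_loss = True

    loss_mask = torch.ones(data.size(), dtype=torch.float)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0
    # loss_mask is now tensor([[1., 1., 0., 1., 0., 1.]])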
# Position ids. # Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, position_ids = torch.arange(seq_length, dtype=torch.long,
...@@ -246,7 +251,8 @@ def get_batch(data_iterator, args, timers): ...@@ -246,7 +251,8 @@ def get_batch(data_iterator, args, timers):
tokens, tokens,
args.eod_token, args.eod_token,
args.reset_position_ids, args.reset_position_ids,
args.reset_attention_mask) args.reset_attention_mask,
args.eod_mask_loss)
# Convert # Convert
if args.fp16: if args.fp16:
attention_mask = attention_mask.half() attention_mask = attention_mask.half()
...@@ -292,7 +298,7 @@ def backward_step(optimizer, model, lm_loss, args, timers): ...@@ -292,7 +298,7 @@ def backward_step(optimizer, model, lm_loss, args, timers):
reduced_losses = lm_loss.view(1) reduced_losses = lm_loss.view(1)
torch.distributed.all_reduce(reduced_losses.data) torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size reduced_losses.data = reduced_losses.data / args.world_size
if not USE_TORCH_DDP: if args.DDP_impl == 'local':
timers('allreduce').start() timers('allreduce').start()
model.allreduce_params(reduce_after=False, model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce) fp32_allreduce=args.fp32_allreduce)
...@@ -343,7 +349,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler, ...@@ -343,7 +349,7 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
def train(model, optimizer, lr_scheduler, def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args): train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model.""" """Train the model."""
# Turn on training mode which enables dropout. # Turn on training mode which enables dropout.
...@@ -369,13 +375,37 @@ def train(model, optimizer, lr_scheduler, ...@@ -369,13 +375,37 @@ def train(model, optimizer, lr_scheduler,
iteration += 1 iteration += 1
# Update losses. # Update losses.
total_lm_loss += lm_loss.data.detach().float() current_lm_loss = lm_loss.data.detach().float()
total_lm_loss += current_lm_loss
# Logging. # Logging.
if args.DDP_impl == 'torch':
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
else:
timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('train_loss', current_lm_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0: if iteration % args.log_interval == 0:
learning_rate = optimizer.param_groups[0]['lr']
avg_lm_loss = total_lm_loss.item() / args.log_interval avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed() elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration, log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters) args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
...@@ -390,14 +420,13 @@ def train(model, optimizer, lr_scheduler, ...@@ -390,14 +420,13 @@ def train(model, optimizer, lr_scheduler,
if report_memory_flag: if report_memory_flag:
report_memory('after {} iterations'.format(iteration)) report_memory('after {} iterations'.format(iteration))
report_memory_flag = False report_memory_flag = False
if USE_TORCH_DDP: timers.log(timers_to_log, normalizer=args.log_interval)
timers.log(['forward', 'backward', 'optimizer',
'batch generator', 'data loader'], # Autoresume
normalizer=args.log_interval) if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
else: check_adlr_autoresume_termination(iteration, model, optimizer,
timers.log(['forward', 'backward', 'allreduce', 'optimizer', lr_scheduler, args)
'batch generator', 'data loader'],
normalizer=args.log_interval)
# Checkpointing # Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0: if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args) save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
...@@ -405,8 +434,8 @@ def train(model, optimizer, lr_scheduler, ...@@ -405,8 +434,8 @@ def train(model, optimizer, lr_scheduler,
# Evaluation # Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid: if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration) prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results( evaluate_and_print_results(prefix, val_data_iterator, model, args,
prefix, val_data_iterator, model, args, timers, False) writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier() torch.distributed.barrier()
...@@ -436,7 +465,7 @@ def evaluate(data_iterator, model, args, timers, verbose=False): ...@@ -436,7 +465,7 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
# Forward evaluation. # Forward evaluation.
lm_loss = forward_step(data_iterator, model, args, timers) lm_loss = forward_step(data_iterator, model, args, timers)
# Reduce across processes. # Reduce across processes.
if isinstance(model, DDP): if isinstance(model, args.DDP_type):
torch.distributed.all_reduce(lm_loss.data) torch.distributed.all_reduce(lm_loss.data)
lm_loss.data = lm_loss.data / args.world_size lm_loss.data = lm_loss.data / args.world_size
...@@ -450,7 +479,8 @@ def evaluate(data_iterator, model, args, timers, verbose=False): ...@@ -450,7 +479,8 @@ def evaluate(data_iterator, model, args, timers, verbose=False):
def evaluate_and_print_results(prefix, data_iterator, model, def evaluate_and_print_results(prefix, data_iterator, model,
args, timers, verbose=False): args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen.""" """Helper function to evaluate and dump results on screen."""
lm_loss = evaluate(data_iterator, model, args, timers, verbose) lm_loss = evaluate(data_iterator, model, args, timers, verbose)
lm_ppl = math.exp(min(20, lm_loss)) lm_ppl = math.exp(min(20, lm_loss))
...@@ -463,6 +493,10 @@ def evaluate_and_print_results(prefix, data_iterator, model, ...@@ -463,6 +493,10 @@ def evaluate_and_print_results(prefix, data_iterator, model,
print_rank_0(string) print_rank_0(string)
print_rank_0('-' * length) print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_loss', lm_loss, iteration)
writer.add_scalar('val_ppl', lm_ppl, iteration)
return lm_loss return lm_loss
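Validation logging for GPT-2 records both the loss and the perplexity derived from it; the exponent is clamped at 20 so a diverging loss cannot overflow. A one-line worked example:

    import math

    lm_loss = 2.5                          # example validation LM loss (nats per token)
    lm_ppl = math.exp(min(20, lm_loss))    # ~12.18; the clamp keeps exp() finite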
...@@ -555,11 +589,27 @@ def main(): ...@@ -555,11 +589,27 @@ def main():
# Arguments. # Arguments.
args = get_args() args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed. # Pytorch distributed.
initialize_distributed(args) initialize_distributed(args)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('Pretrain GPT2 model') print('Pretrain GPT2 model')
print_args(args) print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility. # Random seeds for reproducibility.
set_random_seed(args.seed) set_random_seed(args.seed)
...@@ -576,11 +626,15 @@ def main(): ...@@ -576,11 +626,15 @@ def main():
if train_data is not None: if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \ train_data.batch_sampler.start_iter = args.iteration % \
len(train_data) len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None: if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * \ start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_interval args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \ val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data) len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None: if train_data is not None:
train_data_iterator = iter(train_data) train_data_iterator = iter(train_data)
else: else:
...@@ -598,12 +652,13 @@ def main(): ...@@ -598,12 +652,13 @@ def main():
lr_scheduler, lr_scheduler,
train_data_iterator, train_data_iterator,
val_data_iterator, val_data_iterator,
timers, args) timers, args, writer)
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator, val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, timers, False) model, args, writer, iteration,
timers, False)
if args.save and iteration != 0: if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, save_checkpoint(iteration, model, optimizer,
...@@ -618,7 +673,7 @@ def main(): ...@@ -618,7 +673,7 @@ def main():
# Run on test data. # Run on test data.
prefix = 'the end of training for test data' prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator, evaluate_and_print_results(prefix, test_data_iterator,
model, args, timers, True) model, args, None, 0, timers, True)
if __name__ == "__main__": if __name__ == "__main__":
......
#!/bin/bash #!/bin/bash
CHECKPOINT_PATH=/path/to/checkpoint CHECKPOINT_PATH=checkpoints/gpt2_345m/
MPSIZE=1 MPSIZE=1
NLAYERS=24 NLAYERS=12
NHIDDEN=1024 NHIDDEN=768
NATT=16 NATT=12
MAXSEQLEN=1024 MAXSEQLEN=1024
#SAMPLING ARGS #SAMPLING ARGS
...@@ -26,4 +26,7 @@ python generate_samples.py \ ...@@ -26,4 +26,7 @@ python generate_samples.py \
--out-seq-length $MAXSEQLEN \ --out-seq-length $MAXSEQLEN \
--temperature $TEMP \ --temperature $TEMP \
--top_k $TOPK \ --top_k $TOPK \
--top_p $TOPP --genfile dbg_unconditional.json \
--num-samples 10 \
--top_p $TOPP \
--recompute
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -28,6 +28,8 @@ parser.add_argument('--data-path', type=str, required=True, ...@@ -28,6 +28,8 @@ parser.add_argument('--data-path', type=str, required=True,
help='Data path for evaluation data') help='Data path for evaluation data')
parser.add_argument('--cloze-eval', action='store_true', parser.add_argument('--cloze-eval', action='store_true',
help='Run lambada cloze eval instead of perplexity eval.') help='Run lambada cloze eval instead of perplexity eval.')
parser.add_argument('--strict-lambada', action='store_true',
help='use more difficult formulation of lambada')
parser.add_argument('--webtext-eval', action='store_true', parser.add_argument('--webtext-eval', action='store_true',
help='Run webtext PPL eval instead of wikitext PPL eval.') help='Run webtext PPL eval instead of wikitext PPL eval.')
parser.add_argument('--eval-iters', default=5000, type=int, parser.add_argument('--eval-iters', default=5000, type=int,
...@@ -38,6 +40,9 @@ parser.add_argument('--load-openai', action='store_true', ...@@ -38,6 +40,9 @@ parser.add_argument('--load-openai', action='store_true',
help='Load weights from saved openai/hf checkpoints') help='Load weights from saved openai/hf checkpoints')
parser.add_argument('--cache-dir', type=str, default='cache', parser.add_argument('--cache-dir', type=str, default='cache',
help='directory to cache gpt2 tokenizers') help='directory to cache gpt2 tokenizers')
parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value. '
'This is added for computational efficiency reasons.')
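Padding the vocabulary to a multiple of `--make-vocab-size-divisible-by` keeps the embedding and output projections a convenient shape for tensor cores and for model-parallel splits. For instance, GPT-2's 50,257-token vocabulary padded to a multiple of 128:

    def pad_vocab_size(orig_vocab_size, multiple):
        # Round the vocabulary size up to the next multiple.
        padded = orig_vocab_size
        while padded % multiple != 0:
            padded += 1
        return padded

    print(pad_vocab_size(50257, 128))  # 50304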
args = parser.parse_args() args = parser.parse_args()
multinode_args = '' multinode_args = ''
...@@ -60,18 +65,23 @@ CMD = ' --model-parallel-size {model_par} \ ...@@ -60,18 +65,23 @@ CMD = ' --model-parallel-size {model_par} \
--attention-dropout 0.1 \ --attention-dropout 0.1 \
--fp16 \ --fp16 \
--overlapping-eval 32 \ --overlapping-eval 32 \
--make-vocab-size-divisible-by {make_vocab_size_divisible_by} \
--cache-dir {cache} '.format(model_par=args.model_parallel_size, --cache-dir {cache} '.format(model_par=args.model_parallel_size,
nlayers=args.num_layers, nlayers=args.num_layers,
hidden=args.hidden_size, hidden=args.hidden_size,
model=args.model_path, model=args.model_path,
batch=args.batch_size, batch=args.batch_size,
natt=args.num_attention_heads, natt=args.num_attention_heads,
make_vocab_size_divisible_by=args.make_vocab_size_divisible_by,
cache=args.cache_dir) cache=args.cache_dir)
if args.load_openai: if args.load_openai:
CMD += ' --load-openai ' CMD += ' --load-openai '
if args.cloze_eval: if args.cloze_eval:
CMD += ' --valid-data {} '.format(args.data_path)
CMD += ' --cloze-eval ' CMD += ' --cloze-eval '
if args.strict_lambada:
CMD += ' --strict-lambada '
CMD = 'evaluate_gpt2.py' + CMD CMD = 'evaluate_gpt2.py' + CMD
print('Running Lambada Eval Command:', flush=True) print('Running Lambada Eval Command:', flush=True)
elif args.webtext_eval: elif args.webtext_eval:
......
"""
Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.
Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
                    help='fraction(s) of the data to use for the val/test splits')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
for i, l in enumerate(f.readlines()):
l = l.strip()
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
# Calculate the number of lines to use for each split.
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
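The split sizes follow directly from `--test_percent`: the first value is the validation fraction, the optional second value is the test fraction, and training gets the remainder (ceilings are taken for val/test). For example, with 1,000 total input lines and `--test_percent 0.05 0.01` (made-up numbers):

    import math

    total_lines = 1000
    test_percent = [0.05, 0.01]   # as passed via --test_percent

    dev_lines = math.ceil(test_percent[0] * total_lines)    # 50
    test_lines = math.ceil(test_percent[1] * total_lines)   # 10
    train_lines = total_lines - (test_lines + dev_lines)    # 940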