Commit 48a5e0da authored by Jared Casper

Integrate code from t5_main into existing code.

parent f32a638d
@@ -149,11 +149,11 @@ def parse_args(extra_args_provider=None, defaults={},
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have
    # local DDP and we should set the use-contiguous-buffers-in-ddp.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True

    if args.dataloader_type is None:
        args.dataloader_type = 'single'
@@ -212,7 +212,7 @@ def parse_args(extra_args_provider=None, defaults={},
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    assert args.hidden_size % args.num_attention_heads == 0
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
@@ -625,6 +625,9 @@ def _add_data_args(parser):
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
+   group.add_argument('--vocab-extra-ids', type=int, default=0,
+                      help='Number of additional vocabulary tokens. '
+                           'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
...
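The new --vocab-extra-ids flag feeds the tokenizer changes further down: each extra id becomes an `<extra_id_{i}>` sentinel token appended to the vocabulary for T5 span masking. A minimal, self-contained sketch of that relationship (the value 100 is only an illustration; t5_dataset.py merely asserts that at least one sentinel exists):

```python
# Illustrative only: shows how a --vocab-extra-ids value turns into the
# sentinel token strings that the tokenizer registers (see tokenizer.py below).
vocab_extra_ids = 100  # example value, not a required setting

sentinel_tokens = ["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)]

print(len(sentinel_tokens))  # 100
print(sentinel_tokens[:3])   # ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']
```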
@@ -15,24 +15,25 @@

"""BERT Style dataset."""

-import os
-import time
import numpy as np
import torch
-from torch.utils.data import Dataset

-from megatron import get_tokenizer, get_args
-from megatron import print_rank_0
-from megatron import mpu
-from megatron.data.dataset_utils import get_a_and_b_segments
-from megatron.data.dataset_utils import truncate_segments
-from megatron.data.dataset_utils import create_tokens_and_tokentypes
-from megatron.data.dataset_utils import pad_and_convert_to_numpy
-from megatron.data.dataset_utils import create_masked_lm_predictions
+from megatron import (
+    get_args,
+    get_tokenizer,
+    mpu,
+    print_rank_0
+)
+from megatron.data.dataset_utils import (
+    get_samples_mapping,
+    get_a_and_b_segments,
+    truncate_segments,
+    create_tokens_and_tokentypes,
+    create_masked_lm_predictions
+)

-class BertDataset(Dataset):
+class BertDataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
@@ -49,15 +50,15 @@ class BertDataset(Dataset):
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
-        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
+        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
-                                                   self.max_seq_length,
+                                                   self.max_seq_length - 3, # account for added tokens
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name,
                                                    self.binary_head)

        # Vocab stuff.
        tokenizer = get_tokenizer()
@@ -87,91 +88,6 @@ class BertDataset(Dataset):
                                             self.binary_head)
def get_samples_mapping_(indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
name,
binary_head):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building sapmles index mapping for {} ...'.format(
name))
# First compile and then import.
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
short_seq_prob,
seed,
verbose,
2 if binary_head else 1)
print_rank_0(' > done building sapmles index maping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elasped time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
def build_training_sample(sample,
@@ -225,7 +141,7 @@ def build_training_sample(sample,

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
-    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
+    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
@@ -244,3 +160,35 @@ def build_training_sample(sample,
        'truncated': int(truncated)}
    return train_sample
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
# Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
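For reference, this is what the helper above produces for a toy input: a 5-token sequence (ids made up) padded to max_seq_length=8 with a single masked position. The arrays below follow the same recipe as pad_and_convert_to_numpy.

```python
import numpy as np

# Toy inputs (token ids are made up): 5 real tokens, padded to length 8,
# with position 2 masked and its original id 42 kept as the label.
tokens = [101, 7, 8, 9, 102]
tokentypes = [0, 0, 0, 1, 1]
masked_positions, masked_labels = [2], [42]
pad_id, max_seq_length = 0, 8

padding_length = max_seq_length - len(tokens)
filler = [pad_id] * padding_length

tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
padding_mask_np = np.array([1] * len(tokens) + [0] * padding_length, dtype=np.int64)

labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
labels[masked_positions[0]] = masked_labels[0]
loss_mask[masked_positions[0]] = 1

print(tokens_np)            # [101 7 8 9 102 0 0 0]
print(padding_mask_np)      # [1 1 1 1 1 0 0 0]
print(np.array(labels))     # [-1 -1 42 -1 -1 -1 -1 -1]
print(np.array(loss_mask))  # [0 0 1 0 0 0 0 0]
```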
@@ -19,18 +19,26 @@
# with some modifications.

import math
+import os
import time
import collections

import numpy as np
-from megatron import get_args, print_rank_0
+import torch
+
+from megatron import (
+    get_args,
+    mpu,
+    print_rank_0
+)
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

-DSET_TYPE_STD = 'standard_bert'
+DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
+DSET_TYPE_T5 = 't5'

-DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
+DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]

def get_datasets_weights_and_num_samples(data_prefix,
@@ -153,7 +161,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(1)

    return tokens, tokentypes
@@ -179,7 +187,9 @@ def create_masked_lm_predictions(tokens,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
-                                do_permutation=False):
+                                do_permutation=False,
+                                geometric_dist=False,
+                                masking_style="bert"):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
@@ -219,14 +229,14 @@ def create_masked_lm_predictions(tokens,
    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

-    # Note(mingdachen):
-    # By default, we set the probilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
-    pvals = 1. / np.arange(1, max_ngrams + 1)
-    pvals /= pvals.sum(keepdims=True)
-    if favor_longer_ngram:
-        pvals = pvals[::-1]
+    if not geometric_dist:
+        # Note(mingdachen):
+        # By default, we set the probabilities to favor shorter ngram sequences.
+        pvals = 1. / np.arange(1, max_ngrams + 1)
+        pvals /= pvals.sum(keepdims=True)
+        if favor_longer_ngram:
+            pvals = pvals[::-1]

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
@@ -237,7 +247,7 @@ def create_masked_lm_predictions(tokens,
    np_rng.shuffle(ngram_indexes)

-    masked_lms = []
+    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
@@ -251,9 +261,16 @@ def create_masked_lm_predictions(tokens,
            if index in covered_indexes:
                continue

-        n = np_rng.choice(ngrams[:len(cand_index_set)],
-                          p=pvals[:len(cand_index_set)] /
-                          pvals[:len(cand_index_set)].sum(keepdims=True))
+        if not geometric_dist:
+            n = np_rng.choice(ngrams[:len(cand_index_set)],
+                              p=pvals[:len(cand_index_set)] /
+                              pvals[:len(cand_index_set)].sum(keepdims=True))
+        else:
+            # Sampling "n" from the geometric distribution and clipping it to
+            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
+            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
+            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
@@ -277,24 +294,31 @@ def create_masked_lm_predictions(tokens,
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
-            # 80% of the time, replace with [MASK]
-            if np_rng.random() < 0.8:
-                masked_token = mask_id
-            else:
-                # 10% of the time, keep original
-                if np_rng.random() < 0.5:
-                    masked_token = tokens[index]
-                # 10% of the time, replace with random word
-                else:
-                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
+            if masking_style == "bert":
+                # 80% of the time, replace with [MASK]
+                if np_rng.random() < 0.8:
+                    masked_token = mask_id
+                else:
+                    # 10% of the time, keep original
+                    if np_rng.random() < 0.5:
+                        masked_token = tokens[index]
+                    # 10% of the time, replace with random word
+                    else:
+                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
+            elif masking_style == "t5":
+                masked_token = mask_id
+            else:
+                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

-        assert len(masked_lms) <= num_to_predict
+        masked_spans.append(MaskedLmInstance(
+            index=index_set,
+            label=[tokens[index] for index in index_set]))
+
+        assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
@@ -347,12 +371,13 @@ def create_masked_lm_predictions(tokens,
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    # Sort the spans by the index of the first span
+    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

-    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
+    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)

def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
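The only genuinely new sampling path above is the geometric span-length draw used for T5-style masking. A self-contained sketch of that draw, using the same p=0.2 and clipping to max_ngrams, makes the resulting span-length distribution easy to eyeball:

```python
import numpy as np

np_rng = np.random.RandomState(seed=1234)
max_ngrams = 10  # the value the T5 dataset passes to create_masked_lm_predictions

# Same draw as the geometric_dist branch above: geometric(p=0.2), clipped.
lengths = [min(np_rng.geometric(0.2), max_ngrams) for _ in range(100000)]

counts = np.bincount(lengths, minlength=max_ngrams + 1)[1:]
for n, frac in enumerate(counts / len(lengths), start=1):
    print(n, round(float(frac), 3))
# Length 1 gets roughly 0.2 of the draws, each longer length geometrically
# less, and everything beyond max_ngrams is folded into the last bucket.
```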
@@ -390,9 +415,10 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,

def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
-                                   max_seq_length, masked_lm_prob,
-                                   short_seq_prob, seed, skip_warmup,
-                                   binary_head,
+                                   max_seq_length,
+                                   masked_lm_prob, short_seq_prob, seed,
+                                   skip_warmup, binary_head=False,
+                                   max_seq_length_dec=None,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:
@@ -403,6 +429,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
+                                               max_seq_length_dec,
                                                dataset_type=dataset_type)
    # Blending dataset.
    # Parse the values.
@@ -444,11 +471,12 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
-                                    max_seq_length, masked_lm_prob,
-                                    short_seq_prob, seed, skip_warmup,
-                                    binary_head,
+                                    max_seq_length,
+                                    masked_lm_prob, short_seq_prob, seed,
+                                    skip_warmup, binary_head,
+                                    max_seq_length_dec,
                                     dataset_type='standard_bert'):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)
@@ -489,6 +517,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    def build_dataset(index, name):
        from megatron.data.bert_dataset import BertDataset
        from megatron.data.ict_dataset import ICTDataset
+        from megatron.data.t5_dataset import T5Dataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
@@ -507,7 +536,6 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
-                binary_head=binary_head
            )

            if dataset_type == DSET_TYPE_ICT:
@@ -517,15 +545,27 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                    title_dataset=title_dataset,
                    query_in_block_prob=args.query_in_block_prob,
                    use_one_sent_docs=args.use_one_sent_docs,
+                    binary_head=binary_head,
                    **kwargs
                )
-            else:
+            elif dataset_type == DSET_TYPE_T5:
+                dataset = T5Dataset(
+                    indexed_dataset=indexed_dataset,
+                    masked_lm_prob=masked_lm_prob,
+                    max_seq_length_dec=max_seq_length_dec,
+                    short_seq_prob=short_seq_prob,
+                    **kwargs
+                )
+            elif dataset_type == DSET_TYPE_BERT:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
+                    binary_head=binary_head,
                    **kwargs
                )
+            else:
+                raise NotImplementedError("Dataset type not fully implemented.")

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
@@ -590,4 +630,90 @@ def get_train_valid_test_split_(splits_string, size):
    assert splits_index[-1] == size
    return splits_index
def get_samples_mapping(indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
name,
binary_head):
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
# First compile and then import.
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
verbose,
2 if binary_head else 1)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
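The cache filename built above encodes every parameter that affects the sample mapping, so changing any of them triggers a rebuild instead of silently reusing a stale index. A self-contained sketch of the naming scheme with made-up values (note that BertDataset passes max_seq_length - 3 and T5Dataset passes max_seq_length - 2, so the msl field reflects the reduced length):

```python
import numpy as np

# Hypothetical inputs, for illustration only.
data_prefix = 'my-bert-corpus_text_sentence'
name = 'train'
num_epochs = np.iinfo(np.int32).max - 1   # i.e. "not specified"
max_num_samples = 1000000
max_seq_length = 509                      # 512 - 3 added tokens, as passed by BertDataset
short_seq_prob = 0.10
seed = 1234

indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
    indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
    indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

print(indexmap_filename)
# my-bert-corpus_text_sentence_train_indexmap_1000000mns_509msl_0.10ssp_1234s.npy
```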
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 Style dataset."""
import collections
import numpy as np
import torch
from megatron import get_tokenizer
from megatron.data.dataset_utils import (
create_masked_lm_predictions,
get_samples_mapping
)
class T5Dataset(torch.utils.data.Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, max_seq_length_dec,
short_seq_prob, seed):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.max_seq_length_dec = max_seq_length_dec
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 2, # account for added tokens
short_seq_prob,
self.seed,
self.name,
False)
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
self.bos_id = tokenizer.bos_token_id
self.eos_id = tokenizer.eos_token_id
self.sentinel_tokens = tokenizer.additional_special_tokens_ids
assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx]
sample = []
for index in range(start_index, end_index):
sample.append(self.indexed_dataset[index])
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.max_seq_length_dec,
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng,
self.bos_id, self.eos_id,
self.sentinel_tokens)
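Each __getitem__ call above builds its own RandomState from seed + idx, so the masking drawn for a given sample index is reproducible across epochs and dataloader workers while still varying from sample to sample. A tiny self-contained illustration of that property:

```python
import numpy as np

seed = 1234

def draw(idx):
    # Same construction as in __getitem__ above: one RNG per sample index.
    np_rng = np.random.RandomState(seed=(seed + idx))
    return np_rng.randint(0, 1000, size=3)

print(draw(0))   # always the same triple for idx 0
print(draw(0))   # identical to the line above
print(draw(1))   # a different triple for idx 1
assert (draw(7) == draw(7)).all()
```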
def build_training_sample(sample, target_seq_length,
max_seq_length, max_seq_length_dec,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng, bos_id=None,
eos_id=None, sentinel_tokens=None):
"""Build training sample.
Arguments:
sample: A list of sentences in which each sentence is a list of token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
bos_id: start of decoder example id
eos_id: end of generation id
sentinel_tokens: unique value to be substituted for every replaced span
"""
assert target_seq_length <= max_seq_length
# flatten sentences into one list
tokens = [token for sentence in sample for token in sentence]
# Truncate to `target_seq_length`.
max_num_tokens = target_seq_length
truncated = len(tokens) > max_num_tokens
tokens = tokens[:max_num_tokens]
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
max_ngrams=10, geometric_dist=True, masking_style="t5")
# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask \
= pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id, max_seq_length,
max_seq_length_dec, masked_spans,
bos_id, eos_id, sentinel_tokens)
train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
return train_sample
def pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id,
max_seq_length, max_seq_length_dec,
masked_spans=None, bos_id=None,
eos_id=None, sentinel_tokens=None):
"""Pad sequences and convert them to numpy."""
sentinel_tokens = collections.deque(sentinel_tokens)
t5_input = []
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
(start_index, end_index) = (0, None)
for span in masked_spans:
flag = sentinel_tokens.popleft()
# Append the same tokens in decoder input and output
t5_decoder_in.append(flag)
t5_decoder_in.extend(span.label)
t5_decoder_out.append(flag)
t5_decoder_out.extend(span.label)
end_index = span.index[0]
t5_input.extend(tokens[start_index: end_index])
t5_input.append(flag)
# the next start index is the token after the last span token
start_index = span.index[-1] + 1
# Add <eos> token to the t5_decoder_out
t5_decoder_out.append(eos_id)
# Add the remaining tokens to the t5 input
t5_input.extend(tokens[start_index:])
# assert (len(t5_input) - len(masked_spans)) + \
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
# Some checks.
# Encoder-side padding mask.
num_tokens = len(t5_input)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(masked_positions) == len(masked_labels)
# Tokens..
filler = [pad_id] * padding_length
tokens_enc = np.array(t5_input + filler, dtype=np.int64)
# Decoder-side padding mask.
num_tokens_dec = len(t5_decoder_in)
padding_length_dec = max_seq_length_dec - num_tokens_dec
assert padding_length_dec >= 0
filler_dec = [pad_id] * padding_length_dec
tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
# Create attention masks
enc_mask = make_attention_mask(tokens_enc, tokens_enc)
enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
dec_mask = dec_mask * make_history_mask(tokens_dec_in)
# Labels mask.
labels = t5_decoder_out + ([-1] * padding_length_dec)
labels = np.array(labels, dtype=np.int64)
# Loss mask
loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
loss_mask = np.array(loss_mask, dtype=np.int64)
return tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask
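A compact, self-contained walk-through of the span-corruption layout the function above builds from masked_spans. Token, sentinel, BOS and EOS ids are made up, MaskedLmInstance mirrors the namedtuple defined in dataset_utils.py, and the loop body is the same as above with end_index inlined:

```python
import collections

MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])

# Made-up example: a 10-token sequence with two masked spans.
tokens = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
masked_spans = [MaskedLmInstance(index=[2, 3], label=[13, 14]),
                MaskedLmInstance(index=[7], label=[18])]
sentinels = collections.deque([901, 902])  # hypothetical <extra_id_0>, <extra_id_1> ids
bos_id, eos_id = 801, 802                  # hypothetical [BOS] / [EOS] ids

t5_input = []
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
start_index = 0
for span in masked_spans:
    flag = sentinels.popleft()
    # The sentinel plus the original span tokens go into both decoder streams.
    t5_decoder_in.append(flag)
    t5_decoder_in.extend(span.label)
    t5_decoder_out.append(flag)
    t5_decoder_out.extend(span.label)
    # The encoder input keeps the unmasked prefix and replaces the span
    # with the sentinel.
    t5_input.extend(tokens[start_index:span.index[0]])
    t5_input.append(flag)
    start_index = span.index[-1] + 1
t5_decoder_out.append(eos_id)
t5_input.extend(tokens[start_index:])

print(t5_input)        # [11, 12, 901, 15, 16, 17, 902, 19, 20]
print(t5_decoder_in)   # [801, 901, 13, 14, 902, 18]
print(t5_decoder_out)  # [901, 13, 14, 902, 18, 802]
```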
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def make_attention_mask_3d(source_block, target_block):
"""
Returns a 3-dimensional (3-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
# (batch, source_length, target_length)
# mask = mask.astype(np.int64)
return mask
def make_history_mask(block):
length = block.shape[0]
arange = np.arange(length)
history_mask = (arange[None, ] <= arange[:, None])
history_mask = history_mask.astype(np.int64)
return history_mask
def make_history_mask_3d(block):
batch, length = block.shape
arange = torch.arange(length, device=block.device)
history_mask = (arange[None, ] <= arange[:, None])[None, ]
history_mask = history_mask.expand(batch, length, length)
return history_mask
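The helpers above assume id 0 is padding (hence the >= 1 test) and operate on 1-D id arrays. A short self-contained example of the shapes and values they produce, written with the same expressions; enc_dec_mask corresponds to make_attention_mask(tokens_dec, tokens_enc), i.e. decoder queries attending over encoder keys:

```python
import numpy as np

# A 5-slot encoder sequence with 2 padding positions (pad id 0),
# and a 4-slot decoder sequence with 1 padding position.
tokens_enc = np.array([11, 12, 13, 0, 0], dtype=np.int64)
tokens_dec = np.array([801, 21, 22, 0], dtype=np.int64)

# Same expressions as make_attention_mask / make_history_mask above.
enc_mask = ((tokens_enc[None, :] >= 1) * (tokens_enc[:, None] >= 1)).astype(np.int64)
enc_dec_mask = ((tokens_enc[None, :] >= 1) * (tokens_dec[:, None] >= 1)).astype(np.int64)
history = (np.arange(4)[None, ] <= np.arange(4)[:, None]).astype(np.int64)

print(enc_mask.shape, enc_dec_mask.shape)  # (5, 5) (4, 5)
print(history)
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```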
@@ -18,5 +18,6 @@ from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from .distributed import DistributedDataParallel
from .bert_model import BertModel
from .gpt_model import GPTModel
+from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 model."""
import torch
from megatron import (
get_args,
mpu
)
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import (
openai_gelu,
get_linear_layer,
init_method_normal,
scaled_init_method_normal
)
from .module import MegatronModule
def t5_extended_attention_mask(attention_mask_list):
def attn_mask_postprocess(attn_mask):
# [b, 1, s, s]
extended_attention_mask = attn_mask.unsqueeze(1)
return extended_attention_mask
return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]
def t5_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
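t5_position_ids and t5_extended_attention_mask are pure reshaping: position ids are an arange broadcast to [b, s], and each [b, s_q, s_k] mask gains a singleton head dimension. A small self-contained shape check that mirrors both helpers:

```python
import torch

batch, seq_len = 2, 6
token_ids = torch.randint(0, 100, (batch, seq_len))

# Mirrors t5_position_ids above.
position_ids = torch.arange(seq_len, dtype=torch.long,
                            device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
print(position_ids.shape)  # torch.Size([2, 6])
print(position_ids[0])     # tensor([0, 1, 2, 3, 4, 5])

# Mirrors t5_extended_attention_mask: [b, s_q, s_k] -> [b, 1, s_q, s_k].
attn_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)
extended = attn_mask.unsqueeze(1)
print(extended.shape)      # torch.Size([2, 1, 6, 6])
```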
class T5LMHead(MegatronModule):
"""Masked LM head for T5
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits are distributed or not.
"""
def __init__(self, mpu_vocab_size, parallel_output):
super(T5LMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = 1
self.parallel_output = parallel_output
def forward(self, hidden_states, word_embeddings_weight):
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
return output
class T5Model(MegatronModule):
"""T5 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True):
super(T5Model, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
add_decoder=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
self.lm_head = T5LMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
parallel_output)
self._lm_head_key = 'lm_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
decoder_attn_mask, encoder_decoder_attn_mask,
tokentype_ids=None, lm_labels=None, enc_hidden_states=None):
# Converting the attention masks to proper parameter settings
encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
[encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])
encoder_position_ids = t5_position_ids(encoder_input_ids)
decoder_position_ids = t5_position_ids(decoder_input_ids)
lm_output = self.language_model(encoder_input_ids,
encoder_position_ids,
encoder_attn_mask,
decoder_input_ids,
decoder_position_ids,
decoder_attn_mask,
encoder_decoder_attn_mask,
tokentype_ids=tokentype_ids,
enc_hidden_states=enc_hidden_states)
decoder_output, encoder_output = lm_output
# Output.
lm_logits = self.lm_head(decoder_output,
self.language_model.embedding.word_embeddings.weight)
if lm_labels is None:
return lm_logits, encoder_output
else:
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
return lm_loss, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
@@ -32,10 +32,12 @@ def build_tokenizer(args):
    assert args.vocab_file is not None
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=True)
+                                           lower_case=True,
+                                           vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'BertWordPieceCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=False)
+                                           lower_case=False,
+                                           vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
@@ -127,7 +129,7 @@ class AbstractTokenizer(ABC):

class _BertWordPieceTokenizer(AbstractTokenizer):
    """Original BERT wordpiece tokenizer."""

-    def __init__(self, vocab_file, lower_case=True):
+    def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
        if lower_case:
            name = 'BERT Lower Case'
        else:
@@ -138,6 +140,37 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']
        self.mask_id = self.tokenizer.vocab['[MASK]']
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]',
'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
additional_special_tokens.extend(
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
# self.vocab_size comes from len(vocab)
# and it will increase as we add elements
self.vocab[token] = self.vocab_size
def add_additional_special_tokens(self, tokens_list):
setattr(self, "additional_special_tokens", tokens_list)
for value in tokens_list:
self.add_token(value)
    @property
    def vocab_size(self):
@@ -155,6 +188,10 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
        text_tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
    def decode_token_ids(self, token_ids):
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        exclude_list = ['[PAD]', '[CLS]']
@@ -185,6 +222,40 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
    def mask(self):
        return self.mask_id
@property
def bos_token(self):
""" Beginning of sentence token id """
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
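add_token and add_additional_special_tokens simply append new entries at the end of the existing vocab, so [BOS], [EOS] and the <extra_id_*> sentinels receive contiguous ids above the original WordPiece range. A self-contained sketch of that behaviour on a toy dict standing in for the real WordPiece vocab (ids are illustrative):

```python
# Toy stand-in for the WordPiece vocab; ids below are illustrative only.
vocab = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3, 'hello': 4}
inv_vocab = {v: k for k, v in vocab.items()}

def add_token(token):
    # Same idea as _BertWordPieceTokenizer.add_token: a new token gets the
    # next free id, i.e. the current vocab size.
    if token not in vocab:
        inv_vocab[len(vocab)] = token
        vocab[token] = len(vocab)

for tok in ['[BOS]', '[EOS]'] + ['<extra_id_{}>'.format(i) for i in range(3)]:
    add_token(tok)

print(vocab['[BOS]'], vocab['[EOS]'])                        # 5 6
print([vocab['<extra_id_{}>'.format(i)] for i in range(3)])  # [7, 8, 9]
```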
class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer."""
...
@@ -15,9 +15,11 @@

"""Pretrain BERT"""

+from functools import partial

import torch
import torch.nn.functional as F
-from functools import partial

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
...
@@ -104,7 +104,7 @@ def forward_step(data_iterator, model, input_tensor):
    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain T5"""
from functools import partial
import torch
from megatron import (
get_args,
get_timers,
mpu,
print_rank_0
)
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
assert pre_process and post_process, "T5 doesn't yet support pipelining"
print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0,
parallel_output=True)
return model
def get_batch(data_iterator):
"""Build the batch."""
keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
'enc_mask', 'dec_mask', 'enc_dec_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_enc = data_b['text_enc'].long()
tokens_dec = data_b['text_dec'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()
enc_mask = (data_b['enc_mask'] < 0.5)
dec_mask = (data_b['dec_mask'] < 0.5)
enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
return tokens_enc, tokens_dec, loss_mask, labels, \
enc_mask, dec_mask, enc_dec_mask
def loss_func(loss_mask, output_tensor):
lm_loss_, _ = output_tensor
lm_loss_ = lm_loss_.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss])
return loss, {'lm loss': averaged_losses[0]}
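loss_func reduces the per-token cross-entropy coming out of T5Model to a scalar by averaging only over decoder positions whose loss_mask is 1, i.e. real target tokens rather than padding. A self-contained numeric sketch of that reduction with made-up per-token losses:

```python
import torch

# Made-up per-token losses for a batch of 2 decoder sequences of length 4.
lm_loss_ = torch.tensor([[2.0, 1.0, 3.0, 0.7],
                         [1.5, 0.5, 9.9, 9.9]])
# loss_mask is 1 for real decoder tokens, 0 for padding.
loss_mask = torch.tensor([[1., 1., 1., 1.],
                          [1., 1., 0., 0.]])

# Same reduction as loss_func above: padded positions contribute nothing.
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)  # tensor(1.4500): (2 + 1 + 3 + 0.7 + 1.5 + 0.5) / 6
```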
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator)
timers('batch generator').stop()
# Forward pass through the model.
output_tensor = model(tokens_enc,
tokens_dec,
enc_mask,
dec_mask,
enc_dec_mask,
tokentype_ids=None,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for T5 ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.encoder_seq_length,
max_seq_length_dec=args.decoder_seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
dataset_type='t5')
print_rank_0("> finished creating T5 datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})