Commit 651cc107 authored by Mostofa Patwary

Merging megatron with ICT

parents e919dd8e c601d751
"""AutoAugment data augmentation policy for ImageNet.
-- Begin license text.
MIT License
Copyright (c) 2018 Philip Popien
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-- End license text.
Code adapted from https://github.com/DeepVoltaire/AutoAugment.
This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
policies.
Reference:
[1] https://arxiv.org/abs/1805.09501
"""
import random
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps
_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
class ImageNetPolicy:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
    Usage example as a PyTorch Transform:
    >>> transform = transforms.Compose([transforms.Resize(256),
    ...                                 ImageNetPolicy(),
    ...                                 transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self.policies = [
SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
]
def __call__(self, img):
"""Define call method for ImageNetPolicy class."""
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
"""Define repr method for ImageNetPolicy class."""
return "ImageNetPolicy"
class SubPolicy:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def __init__(
self,
operation1,
probability1,
magnitude_idx1,
operation2,
probability2,
magnitude_idx2,
fillcolor,
):
"""Initialize a SubPolicy.
        Args:
            operation1 (str): Key specifying the first augmentation operation.
                There are fourteen key values altogether (see supported_ops
                below listing supported operations).
            probability1 (float): Probability within [0., 1.] of applying the
                first augmentation operation.
            magnitude_idx1 (int): Integer specifying the strength of the first
                operation as an index further used to derive the magnitude
                from a range of possible values.
            operation2 (str): Key specifying the second augmentation operation.
            probability2 (float): Probability within [0., 1.] of applying the
                second augmentation operation.
            magnitude_idx2 (int): Integer specifying the strength of the
                second operation as an index further used to derive the
                magnitude from a range of possible values.
            fillcolor (tuple): RGB color components of the color to be used
                for filling.
        """
# List of supported operations for operation1 and operation2.
supported_ops = [
"shearX",
"shearY",
"translateX",
"translateY",
"rotate",
"color",
"posterize",
"solarize",
"contrast",
"sharpness",
"brightness",
"autocontrast",
"equalize",
"invert",
]
        assert (operation1 in supported_ops) and (
            operation2 in supported_ops
        ), "SubPolicy: one of operation1 or operation2 refers to an unsupported operation."
assert (
0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert (
isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert (
isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels = _MAX_LEVEL + 1
ranges = {
"shearX": np.linspace(0, 0.3, num_levels),
"shearY": np.linspace(0, 0.3, num_levels),
"translateX": np.linspace(0, 150 / 331, num_levels),
"translateY": np.linspace(0, 150 / 331, num_levels),
"rotate": np.linspace(0, 30, num_levels),
"color": np.linspace(0.0, 0.9, num_levels),
"posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
np.int
),
"solarize": np.linspace(256, 0, num_levels), # range [0, 256]
"contrast": np.linspace(0.0, 0.9, num_levels),
"sharpness": np.linspace(0.0, 0.9, num_levels),
"brightness": np.linspace(0.0, 0.9, num_levels),
"autocontrast": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"equalize": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"invert": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
}
def rotate_with_fill(img, magnitude):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated = img.convert("RGBA").rotate(magnitude)
rotated_filled = Image.composite(
rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
)
return rotated_filled.convert(img.mode)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
magnitude * img.size[0] * random.choice([-1, 1]),
0,
1,
0,
),
fillcolor=fillcolor,
),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
0,
0,
1,
magnitude * img.size[1] * random.choice([-1, 1]),
),
fillcolor=fillcolor,
),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * random.choice([-1, 1])
),
"posterize": lambda img, magnitude: ImageOps.posterize(
img, magnitude
),
"solarize": lambda img, magnitude: ImageOps.solarize(
img, magnitude
),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self.probability1 = probability1
self.operation1 = func_dict[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self.probability2 = probability2
self.operation2 = func_dict[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if random.random() < self.probability1:
img = self.operation1(img, self.magnitude1)
# Randomly apply operation 2.
if random.random() < self.probability2:
img = self.operation2(img, self.magnitude2)
return img
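A quick way to sanity-check the policy outside of a transforms pipeline is to call it directly on a PIL image. The sketch below is editor-added and illustrative only; "sample.jpg" is a placeholder path, and megatron.data.autoaugment is the import path used elsewhere in this commit.

# Illustrative sketch (not part of the module): apply the policy to one image.
from PIL import Image

from megatron.data.autoaugment import ImageNetPolicy

policy = ImageNetPolicy()                        # 25 fixed sub-policies
image = Image.open("sample.jpg").convert("RGB")  # placeholder input path
augmented = policy(image)                        # one sub-policy chosen at random
augmented.save("sample_augmented.jpg")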
@@ -36,13 +36,14 @@ class BertDataset(Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
-                max_seq_length, short_seq_prob, seed):
+                max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
+       self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset

@@ -55,7 +56,8 @@ class BertDataset(Dataset):
                                                      self.max_seq_length,
                                                      short_seq_prob,
                                                      self.seed,
-                                                     self.name)
+                                                     self.name,
+                                                     self.binary_head)

        # Vocab stuff.
        tokenizer = get_tokenizer()

@@ -81,7 +83,8 @@ class BertDataset(Dataset):
                                    self.vocab_id_to_token_dict,
                                    self.cls_id, self.sep_id,
                                    self.mask_id, self.pad_id,
-                                   self.masked_lm_prob, np_rng)
+                                   self.masked_lm_prob, np_rng,
+                                   self.binary_head)


def get_samples_mapping_(indexed_dataset,

@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
                         max_seq_length,
                         short_seq_prob,
                         seed,
-                        name):
+                        name,
+                        binary_head):

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "

@@ -128,8 +132,6 @@ def get_samples_mapping_(indexed_dataset,
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        # First compile and then import.
-       from megatron.data.dataset_utils import compile_helper
-       compile_helper()
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,

@@ -139,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
-           verbose)
+           verbose,
+           2 if binary_head else 1)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(

@@ -175,7 +178,7 @@ def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
-                         masked_lm_prob, np_rng):
+                         masked_lm_prob, np_rng, binary_head):
    """Build training sample.

    Arguments:

@@ -195,12 +198,21 @@ def build_training_sample(sample,
        the upper bound whereas the numpy one is exclusive.
    """

-   # We assume that we have at least two sentences in the sample
-   assert len(sample) > 1
+   if binary_head:
+       # We assume that we have at least two sentences in the sample
+       assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
-   tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
+   if binary_head:
+       tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
+                                                                 np_rng)
+   else:
+       tokens_a = []
+       for j in range(len(sample)):
+           tokens_a.extend(sample[j])
+       tokens_b = []
+       is_next_random = False

    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
...
@@ -49,13 +49,6 @@ class BlendableDataset(torch.utils.data.Dataset):
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

-       if torch.distributed.get_rank() == 0:
-           from megatron.data.dataset_utils import compile_helper
-           compile_helper()
-           # Simple barrier
-           tmp = torch.cuda.LongTensor([1])
-           torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())
-
        from megatron.data import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
...
@@ -17,7 +17,7 @@

import torch
+import random

from megatron import get_args
from megatron import mpu

@@ -30,12 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples):
    args = get_args()

    # Megatron sampler
-   batch_sampler = MegatronPretrainingSampler(
-       total_samples=len(dataset),
-       consumed_samples=consumed_samples,
-       micro_batch_size=args.micro_batch_size,
-       data_parallel_rank=mpu.get_data_parallel_rank(),
-       data_parallel_size=mpu.get_data_parallel_world_size())
+   if args.dataloader_type == 'single':
+       batch_sampler = MegatronPretrainingSampler(
+           total_samples=len(dataset),
+           consumed_samples=consumed_samples,
+           micro_batch_size=args.micro_batch_size,
+           data_parallel_rank=mpu.get_data_parallel_rank(),
+           data_parallel_size=mpu.get_data_parallel_world_size())
+   elif args.dataloader_type == 'cyclic':
+       batch_sampler = MegatronPretrainingRandomSampler(
+           total_samples=len(dataset),
+           consumed_samples=consumed_samples,
+           micro_batch_size=args.micro_batch_size,
+           data_parallel_rank=mpu.get_data_parallel_rank(),
+           data_parallel_size=mpu.get_data_parallel_world_size())
+   else:
+       raise Exception('{} dataloader type is not supported.'.format(
+           args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,

@@ -43,10 +54,8 @@ def build_pretraining_data_loader(dataset, consumed_samples):
                                       num_workers=args.num_workers,
                                       pin_memory=True)


class MegatronPretrainingSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.

@@ -54,8 +63,8 @@ class MegatronPretrainingSampler:
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
-       self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
-           data_parallel_size
+       self.micro_batch_times_data_parallel_size = \
+           self.micro_batch_size * data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \

@@ -69,11 +78,9 @@ class MegatronPretrainingSampler:
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        batch = []
        # Last batch if not complete will be dropped.

@@ -84,3 +91,57 @@ class MegatronPretrainingSampler:
                end_idx = start_idx + self.micro_batch_size
                yield batch[start_idx:end_idx]
                batch = []

class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
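This sampler is what the new dataloader_type == 'cyclic' branch above selects. It drops the trailing partial global batch, derives the shuffle seed from the epoch implied by consumed_samples, and gives each data-parallel rank its own contiguous bucket of shuffled positions. The following is a standalone, editor-added sketch of that arithmetic with illustrative numbers, not values from any real run.

# Standalone sketch of the sharding arithmetic used by
# MegatronPretrainingRandomSampler (illustrative numbers only).
import torch

total_samples = 1000
micro_batch_size = 4
data_parallel_size = 8
data_parallel_rank = 3
consumed_samples = 256

micro_times_dp = micro_batch_size * data_parallel_size       # 32
last_batch_size = total_samples % micro_times_dp             # 8, always dropped
active_total_samples = total_samples - last_batch_size       # 992
epoch = consumed_samples // active_total_samples             # 0, seeds the shuffle
current_epoch_samples = consumed_samples % active_total_samples

# Each rank owns a contiguous bucket of shuffled positions.
bucket_size = (total_samples // micro_times_dp) * micro_batch_size  # 124
bucket_offset = current_epoch_samples // data_parallel_size         # 32
start_idx = data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
print(len(idx_range))  # samples this rank has yet to yield in this epoch (92)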
@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    #print(len_a, len_b, max_num_tokens)
    assert len_a > 0
-   assert len_b > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:

@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
-   # [SEP].
-   tokens.append(sep_id)
-   tokentypes.append(1)
+   if tokens_b:
+       # [SEP].
+       tokens.append(sep_id)
+       tokentypes.append(1)

    return tokens, tokentypes

@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
+                                   binary_head,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:

@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
+                                               binary_head,
                                                dataset_type=dataset_type)

    # Blending dataset.
    # Parse the values.

@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
-           seed, skip_warmup, dataset_type=dataset_type)
+           seed, skip_warmup, binary_head, dataset_type=dataset_type)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:

@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
+                                    binary_head,
                                     dataset_type='standard_bert'):

    if dataset_type not in DSET_TYPES:

@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
-               seed=seed
+               seed=seed,
+               binary_head=binary_head
            )

            if dataset_type == DSET_TYPE_ICT:
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""GPT2 style dataset."""
+"""GPT style dataset."""

import os
import time

@@ -107,7 +107,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
-           dataset = GPT2Dataset(name, data_prefix,
+           dataset = GPTDataset(name, data_prefix,
                                  documents, indexed_dataset,
                                  train_valid_test_num_samples[index],
                                  seq_length, seed)

@@ -136,7 +136,7 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    return indexed_dataset


-class GPT2Dataset(torch.utils.data.Dataset):
+class GPTDataset(torch.utils.data.Dataset):

    def __init__(self, name, data_prefix, documents, indexed_dataset,
                 num_samples, seq_length, seed):

@@ -269,8 +269,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
        start_time = time.time()
        # Use C++ implementation for speed.
        # First compile and then import.
-       from megatron.data.dataset_utils import compile_helper
-       compile_helper()
        from megatron.data import helpers
        assert doc_idx.dtype == np.int32
        assert sizes.dtype == np.int32
...
@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
    /* Training sample length. */
+   if (short_seq_ratio == 0) {
+     return max_length;
+   }
    const auto random_number = rand32_gen();
    if ((random_number % short_seq_ratio) == 0) {
      return 2 + random_number % (max_length - 1);

@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
-                            const bool verbose) {
+                            const bool verbose,
+                            const int32_t min_num_sent) {
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length.

@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
-   assert(short_seq_prob > 0.0);
+   assert(short_seq_prob >= 0.0);
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
    auto sizes = sizes_.unchecked<1>();

    // For efficiency, convert probability to ratio. Note: rand() generates int.
-   const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+   int32_t short_seq_ratio = 0;
+   if (short_seq_prob > 0) {
+     short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+   }

    if (verbose) {
      const auto sent_start_index = docs[0];

@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
      }

      // If we have more than two sentences.
-     if ((num_remain_sent > 1) && (!contains_long_sentence)) {
+     if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {

        // Set values.
        auto seq_len = int32_t{0};

@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
          // and if we have reached end of the document.
          if (((seq_len >= target_seq_len) &&
               (num_remain_sent > 1) &&
-              (num_sent > 1) ) || (num_remain_sent == 0)) {
+              (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {

            // Check for overflow.
            if ((3 * map_index + 2) >

@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed,
-                       const bool verbose) {
+                       const bool verbose,
+                       const int32_t min_num_sent) {

    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
       if (verbose) {

@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
       }
       return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                           max_num_samples, max_seq_length,
-                                          short_seq_prob, seed, verbose);
+                                          short_seq_prob, seed, verbose,
+                                          min_num_sent);
    } else {
       if (verbose) {
          cout << " using uint32 for data mapping..." << endl << std::flush;
       }
       return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                           max_num_samples, max_seq_length,
-                                          short_seq_prob, seed, verbose);
+                                          short_seq_prob, seed, verbose,
+                                          min_num_sent);
    }
}
...
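For intuition: the C++ helper works with an integer ratio rather than a float probability, shortening a sample whenever rand() % short_seq_ratio == 0, and the new min_num_sent argument (passed as 2 if binary_head else 1 from the Python side) now lets single-sentence documents through when the next-sentence head is off. The Python sketch below is editor-added and illustrative only; it mirrors the ratio trick, not the real std::mt19937-based code.

# Intuition-only sketch of the probability-to-ratio conversion in helpers.cpp.
import random

def target_sample_len(short_seq_prob, max_length, rng):
    short_seq_ratio = round(1.0 / short_seq_prob) if short_seq_prob > 0 else 0
    if short_seq_ratio == 0:       # short_seq_prob == 0: never shorten
        return max_length
    r = rng.randrange(2**32)
    if r % short_seq_ratio == 0:   # happens with probability ~short_seq_prob
        return 2 + r % (max_length - 1)
    return max_length

rng = random.Random(1234)
lengths = [target_sample_len(0.1, 512, rng) for _ in range(10000)]
print(sum(l < 512 for l in lengths) / len(lengths))  # roughly 0.1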
@@ -49,7 +49,7 @@ class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                seed, use_titles=True, use_one_sent_docs=False):
+                seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
...

@@ -152,10 +152,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

-       # compile/bind the C++ helper code
-       from megatron.data.dataset_utils import compile_helper
-       compile_helper()
-
        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
]
fp16_t = transforms.ConvertImageDtype(torch.half)
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
)
return train_data, val_data
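A minimal usage sketch for the function above, assuming an ImageNet-style directory with train/ and val/ subfolders; "/data/imagenet" is a placeholder path and the DataLoader settings are illustrative, not taken from this commit.

# Illustrative sketch: build the datasets and wrap the training split in a
# standard PyTorch DataLoader.
import torch

train_data, val_data = build_train_valid_datasets(
    data_path=["/data/imagenet"],  # expects train/ and val/ subfolders
    crop_size=224,
    color_jitter=True,
)
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=32, shuffle=True, num_workers=4, pin_memory=True
)
images, labels = next(iter(train_loader))
print(images.shape, images.dtype)  # e.g. torch.Size([32, 3, 224, 224]) torch.float16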
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
class tofp16(nn.Module):
"""
Utility module that implements::
def forward(self, input):
return input.half()
"""
def __init__(self):
super(tofp16, self).__init__()
def forward(self, input):
return input.half()
def BN_convert_float(module):
"""
Utility function for network_to_half().
Retained for legacy purposes.
"""
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
module.float()
for child in module.children():
BN_convert_float(child)
return module
def network_to_half(network):
"""
Convert model to half precision in a batchnorm-safe way.
Retained for legacy purposes. It is recommended to use FP16Model.
"""
return nn.Sequential(tofp16(), BN_convert_float(network.half()))
def convert_module(module, dtype):
"""
Converts a module's immediate parameters and buffers to dtype.
"""
for param in module.parameters(recurse=False):
if param is not None:
if param.data.dtype.is_floating_point:
param.data = param.data.to(dtype=dtype)
if param._grad is not None and param._grad.data.dtype.is_floating_point:
param._grad.data = param._grad.data.to(dtype=dtype)
for buf in module.buffers(recurse=False):
if buf is not None and buf.data.dtype.is_floating_point:
buf.data = buf.data.to(dtype=dtype)
def convert_network(network, dtype):
"""
Converts a network's parameters and buffers to dtype.
"""
for module in network.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
continue
convert_module(module, dtype)
return network
class FP16Model(nn.Module):
"""
Convert model to half precision in a batchnorm-safe way.
"""
def __init__(self, network):
super(FP16Model, self).__init__()
self.network = convert_network(network, dtype=torch.half)
def forward(self, *inputs):
inputs = tuple(t.half() for t in inputs)
return self.network(*inputs)
def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
`Training Neural Networks with Mixed Precision: Real Examples`_.
Args:
model (torch.nn.Module): Existing Pytorch model
flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
Returns:
A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element.
Example::
model_params, master_params = prep_param_lists(model)
.. warning::
Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
.. _`Training Neural Networks with Mixed Precision: Real Examples`:
http://on-demand.gputechconf.com/gtc/2018/video/S81012/
"""
model_params = [param for param in model.parameters() if param.requires_grad]
if flat_master:
# Give the user some more useful error messages
try:
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.")
raise
master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True
# master_params.register_hook(backwards_debug_hook)
if master_params.grad is None:
master_params.grad = master_params.new(*master_params.size())
return model_params, [master_params]
else:
master_params = [param.clone().float().detach() for param in model_params]
for param in master_params:
param.requires_grad = True
return model_params, master_params
def model_grads_to_master_grads(model_params, master_params, flat_master=False):
"""
Copy model gradients to master gradients.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
"""
if flat_master:
# The flattening may incur one more deep copy than is necessary.
master_params[0].grad.data.copy_(
_flatten_dense_tensors([p.grad.data for p in model_params]))
else:
for model, master in zip(model_params, master_params):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size()))
else:
master.grad = None
model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [p.grad for p in master_params if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[model_grads, master_grads],
1.0)
def master_params_to_model_params(model_params, master_params, flat_master=False):
"""
Copy master parameters to model parameters.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
"""
if flat_master:
for model, master in zip(model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master)
else:
for model, master in zip(model_params, master_params):
model.data.copy_(master.data)
# Backward compatibility fixes
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm
# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm
# else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_
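Putting the helpers in this file together, a manual mixed-precision step looks roughly like the sketch below. It is editor-added and illustrative only: the toy model, plain SGD optimizer, and static loss scale of 128 are assumptions, and it needs a CUDA device plus Apex, as this module already does.

# Illustrative sketch of one fp16 training step using the helpers above.
import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda().half()
model_params, master_params = prep_param_lists(model)
optimizer = torch.optim.SGD(master_params, lr=1e-2)
loss_scale = 128.0

x = torch.randn(8, 16, device="cuda", dtype=torch.half)
target = torch.randn(8, 4, device="cuda", dtype=torch.half)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x).float(), target.float())
(loss * loss_scale).backward()                 # scaled fp16 gradients

model_grads_to_master_grads(model_params, master_params)
for p in master_params:                        # unscale in fp32
    p.grad.data.mul_(1.0 / loss_scale)
optimizer.step()
master_params_to_model_params(model_params, master_params)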
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.
Args:
scale (float, optional, default=1.0): The loss scale.
"""
def __init__(self, scale=1):
self.cur_scale = scale
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
return False
def update_scale(self, overflow):
pass
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
# If output_tensor_grad is None, this is the last stage, and
# output_tensor is actually the loss and needs to be scaled.
# Otherwise, output_tensor does not need to be scaled again since
# output_tensor_grad is already scaled.
if output_tensor_grad is None:
scaled_output_tensor = output_tensor * self.loss_scale
else:
scaled_output_tensor = output_tensor
torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the
    ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
"""
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow(self, params):
overflow = self.has_overflow_serial(params)
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
            # We want to check if instance is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
# `overflow` is boolean indicating whether the gradient overflowed
def update_scale(self, overflow):
if not hasattr(self, 'min_scale'):
self.min_scale = 1
if not hasattr(self, 'delayed_shift'):
self.delayed_shift = 1
if not hasattr(self, 'cur_hysteresis'):
self.cur_hysteresis = 1
if not hasattr(self, 'consecutive_hysteresis'):
self.consecutive_hysteresis = True
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
else:
if self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
if not self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
self.cur_scale *= self.scale_factor
self.cur_iter += 1
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
# If output_tensor_grad is None, this is the last stage, and
# output_tensor is actually the loss and needs to be scaled.
# Otherwise, output_tensor does not need to be scaled again since
# output_tensor_grad is already scaled.
if output_tensor_grad is None:
scaled_output_tensor = output_tensor * self.loss_scale
else:
scaled_output_tensor = output_tensor
torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
retain_graph=retain_graph)
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]
learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()
for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)
"""
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""For backward compatibility, we need the class definitions to deserialize."""
class LossScaler:
def __init__(self, scale=1):
self.cur_scale = scale
class DynamicLossScaler:
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis
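These stubs exist only so that checkpoints saved by older code, which pickled a loss-scaler instance, can still be loaded. A hedged, editor-added illustration follows; the checkpoint filename is a placeholder.

# Illustrative only: unpickling an old checkpoint that stored a LossScaler or
# DynamicLossScaler instance requires these class definitions to be importable.
import torch

state = torch.load("old_checkpoint.pt", map_location="cpu")  # placeholder path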
...@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward( ...@@ -81,7 +81,6 @@ __global__ void scaled_masked_softmax_warp_forward(
const uint8_t *mask, const uint8_t *mask,
const acc_t scale, const acc_t scale,
int micro_batch_size, int micro_batch_size,
int stride,
int element_count, int element_count,
int pad_batches) int pad_batches)
{ {
...@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward( ...@@ -111,9 +110,9 @@ __global__ void scaled_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch // there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x; int local_idx = threadIdx.x;
src += first_batch * stride + local_idx; src += first_batch * element_count + local_idx;
dst += first_batch * stride + local_idx; dst += first_batch * element_count + local_idx;
mask += pad_first_batch * stride + local_idx; mask += pad_first_batch * element_count + local_idx;
// load data from global memory // load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS]; acc_t elements[WARP_BATCH][WARP_ITERATIONS];
...@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward( ...@@ -185,7 +184,6 @@ __global__ void scaled_masked_softmax_warp_backward(
const input_t *output, const input_t *output,
acc_t scale, acc_t scale,
int micro_batch_size, int micro_batch_size,
int stride,
int element_count) int element_count)
{ {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
...@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward( ...@@ -209,7 +207,7 @@ __global__ void scaled_masked_softmax_warp_backward(
int local_idx = threadIdx.x; int local_idx = threadIdx.x;
// the first element to process by the current thread // the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx; int thread_offset = first_batch * element_count + local_idx;
grad += thread_offset; grad += thread_offset;
output += thread_offset; output += thread_offset;
gradInput += thread_offset; gradInput += thread_offset;
@@ -277,20 +275,19 @@ void dispatch_scaled_masked_softmax_forward(
     const input_t *src,
     const uint8_t *mask,
     const input_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -302,59 +299,59 @@ void dispatch_scaled_masked_softmax_forward(
         constexpr int threads_per_block = 128;
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
-        TORCH_INTERNAL_ASSERT(seq_len%batches_per_block == 0);
-        dim3 blocks(seq_len/batches_per_block, attn_heads, batches);
+        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
+        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 1: // 2
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 2: // 4
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 3: // 8
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 4: // 16
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 5: // 32
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 6: // 64
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 7: // 128
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 8: // 256
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 9: // 512
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 10: // 1024
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             case 11: // 2048
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, softmax_elements_stride, softmax_elements, pad_batches);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
             default:
                 break;
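A quick way to read this dispatcher after the change: batch_count now counts one softmax row per (batch, head, query position), and the grid is laid out over query_seq_len rather than a single square seq_len. The Python sketch below mirrors that arithmetic for illustration only; the batches_per_warp rule and the warp-size cap are assumptions taken from the kernel constants, which are not shown in this hunk.

    # Illustrative mirror of the C++ launch-geometry arithmetic (not the real code).
    def launch_geometry(query_seq_len, key_seq_len, batches, attn_heads,
                        warp_size_max=32, threads_per_block=128):
        log2_elements = (key_seq_len - 1).bit_length()   # log2_ceil(key_seq_len)
        next_power_of_two = 1 << log2_elements
        # Each warp covers one row of key_seq_len elements (two rows for short rows;
        # the <=128 threshold is an assumption about the kernel's batches_per_warp).
        warp_size = min(next_power_of_two, warp_size_max)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        # One softmax row per (batch, head, query position).
        batch_count = batches * attn_heads * query_seq_len
        assert query_seq_len % batches_per_block == 0
        blocks = (query_seq_len // batches_per_block, attn_heads, batches)
        threads = (warp_size, warps_per_block, 1)
        return blocks, threads, batch_count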
@@ -368,19 +365,18 @@ void dispatch_scaled_masked_softmax_backward(
     input_t *grad,
     const input_t *output,
     const acc_t scale,
-    int softmax_elements,
-    int softmax_elements_stride,
+    int query_seq_len,
+    int key_seq_len,
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
-    if (softmax_elements == 0) {
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
+    if (key_seq_len == 0) {
         return;
     } else {
-        int log2_elements = log2_ceil(softmax_elements);
+        int log2_elements = log2_ceil(key_seq_len);
         const int next_power_of_two = 1 << log2_elements;
-        int seq_len = softmax_elements;
-        int batch_count = batches * attn_heads * seq_len;
+        int batch_count = batches * attn_heads * query_seq_len;

         // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
         int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
@@ -399,51 +395,51 @@ void dispatch_scaled_masked_softmax_backward(
         switch (log2_elements) {
             case 0: // 1
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 1: // 2
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 2: // 4
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 3: // 8
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 4: // 16
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 5: // 32
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 6: // 64
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 7: // 128
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 8: // 256
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 9: // 512
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 10: // 1024
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             case 11: // 2048
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
-                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
             default:
                 break;
@@ -37,17 +37,19 @@ torch::Tensor fwd_cuda(
   const int batches = input.size(0);
   const int pad_batches = mask.size(0);
   const int attn_heads = input.size(1);
-  const int seq_len = input.size(2);
-  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+  const int query_seq_len = input.size(2);
+  const int key_seq_len = input.size(3);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-  TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-  TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

   // Output
   auto act_options = input.options().requires_grad(false);
   torch::Tensor softmax_results =
-      torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

   // Softmax Intermediate Result Ptr
   void* input_ptr = static_cast<void*>(input.data_ptr());
@@ -59,8 +61,8 @@ torch::Tensor fwd_cuda(
       reinterpret_cast<const half*>(input_ptr),
       reinterpret_cast<const uint8_t*>(mask_ptr),
       scale_factor,
-      seq_len,
-      seq_len,
+      query_seq_len,
+      key_seq_len,
       batches,
       attn_heads,
       pad_batches);
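With query_seq_len and key_seq_len decoupled, the wrapper now accepts rectangular score tensors instead of requiring square [seq_len, seq_len] inputs. The sketch below restates the shape contract it enforces from the Python side; the function name and its use as a stand-alone helper are assumptions, not the real binding.

    import torch

    def scaled_masked_softmax_shapes(scores, mask):
        """Illustrates the shape checks fwd_cuda performs; not the real extension API."""
        batches, attn_heads, query_seq_len, key_seq_len = scores.shape
        assert key_seq_len <= 2048 and query_seq_len > 1
        # The mask may be shared across the batch (pad_batches == 1) or per-sample.
        assert mask.size(0) in (1, batches)
        assert mask.shape[1:] == (1, query_seq_len, key_seq_len)
        # The output keeps the same rectangular shape as the input scores.
        return torch.empty(batches, attn_heads, query_seq_len, key_seq_len,
                           dtype=scores.dtype, device=scores.device)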
@@ -78,8 +80,8 @@ torch::Tensor bwd_cuda(
   //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
   const int batches = output_grads.size(0);
   const int attn_heads = output_grads.size(1);
-  const int seq_len = output_grads.size(2);
-  TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+  const int query_seq_len = output_grads.size(2);
+  const int key_seq_len = output_grads.size(3);

   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
@@ -89,8 +91,8 @@ torch::Tensor bwd_cuda(
       reinterpret_cast<half*>(output_grads_ptr),
       reinterpret_cast<half const*>(softmax_results.data_ptr()),
       scale_factor,
-      seq_len,
-      seq_len,
+      query_seq_len,
+      key_seq_len,
       batches,
       attn_heads);
@@ -83,7 +83,8 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
-    _ = _build_tokenizer(args)
+    if args.vocab_file:
+        _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()
@@ -131,12 +132,13 @@ def _set_tensorboard_writer(args):
                                    'tensorboard writer')
    if hasattr(args, 'tensorboard_dir') and \
-       args.tensorboard_dir and args.rank == (args.world_size -1):
+       args.tensorboard_dir and args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
-                log_dir=args.tensorboard_dir)
+                log_dir=args.tensorboard_dir,
+                max_queue=args.tensorboard_queue_size)
        except ModuleNotFoundError:
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
@@ -78,6 +78,13 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Autoresume.
        _init_autoresume()

+        # Compile dataset C++ code.
+        if torch.distributed.get_rank() == 0:
+            from megatron.data.dataset_utils import compile_helper
+            compile_helper()
+
+        # Simple barrier
+        torch.distributed.barrier()

        # No continuation function
        return None
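The added block makes only rank 0 build the dataset C++ helpers and then synchronizes every rank, so workers never race to compile (or import) a half-built extension. A minimal sketch of the same pattern follows, with a hypothetical make-based build step standing in for compile_helper, whose internals are not part of this diff.

    import subprocess
    import torch

    def build_cpp_helpers_once(source_dir):
        """Hypothetical helper: rank 0 builds, everyone waits at the barrier."""
        if torch.distributed.get_rank() == 0:
            # Assumption: the helpers are built by a plain Makefile in source_dir.
            subprocess.run(['make', '-C', source_dir], check=True)
        # All ranks block here until the build artifacts exist.
        torch.distributed.barrier()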
@@ -30,10 +30,17 @@ def import_layernorm(fp32_residual_connection):
 from .distributed import *
-from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
+from .bert_model import (BertModel,
+                         BertModelFirstStage,
+                         BertModelIntermediateStage,
+                         BertModelLastStage)
 from .realm_model import ICTBertModel
-from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
-from .utils import get_params_for_weight_decay_optimization
+from .gpt_model import (GPTModel,
+                        GPTModelFirstStage,
+                        GPTModelIntermediateStage,
+                        GPTModelLastStage)
 from .language_model import get_language_model
+from .module import FP16Module
+from .realm_model import ICTBertModel
@@ -19,6 +19,7 @@ import torch

 from megatron import get_args
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model import import_layernorm
@@ -26,11 +27,7 @@ from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
-from megatron.module import MegatronModule, PipelinedMegatronModule
-
-def bert_attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores
+from .module import MegatronModule

 def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
@@ -77,9 +74,7 @@ class BertLMHead(MegatronModule):
         args = get_args()

         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        self.bias.tensor_model_parallel = True
-        self.bias.partition_dim = 0
-        self.bias.stride = 1
+        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
@@ -127,7 +122,7 @@ def post_language_model_processing(lm_output, pooled_output,
     return lm_loss, binary_logits


-class BertModelBase(PipelinedMegatronModule):
+class BertModelBase(MegatronModule):
     """Bert Language model."""

     def __init__(self, num_tokentypes=2, add_binary_head=True,
@@ -144,9 +139,9 @@ class BertModelBase(MegatronModule):
                                                   args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
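The three hand-set attributes on self.bias are replaced by a single mpu helper call, and the explicit bert_attention_mask_func is replaced by passing AttnMaskType.padding to get_language_model. Assuming the helper simply records the same partition metadata on the parameter (its real implementation is not in this diff), the call is roughly equivalent to the sketch below.

    # Rough equivalent of mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1).
    # This is an assumption about the helper's effect, not its actual source.
    def set_tensor_model_parallel_attributes(param, is_parallel, dim, stride):
        # Mark the parameter as tensor-model-parallel and record how it is split.
        setattr(param, 'tensor_model_parallel', is_parallel)
        setattr(param, 'partition_dim', dim)
        setattr(param, 'partition_stride', stride)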