Commit f17a3933 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Merge branch 'main' into lmcafee/embed-standalone

parents 804ed2e6 fd5469aa
......@@ -538,6 +538,9 @@ def _add_initialization_args(parser):
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
group.add_argument('--data-parallel-random-init', action='store_true',
help='Enable random initialization of params '
'across data parallel ranks')
group.add_argument('--init-method-std', type=float, default=0.02,
help='Standard deviation of the zero mean normal '
'distribution used for weight initialization.')
......@@ -701,9 +704,6 @@ def _add_distributed_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--deallocate-pipeline-outputs', action='store_true',
default=False, help='If set, pipeline output tensors '
'are deallocated during the forward pass.')
group.add_argument('--standalone-embed-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
......@@ -859,11 +859,20 @@ def _add_vit_args(parser):
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
return parser
......@@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
if args.data_parallel_random_init:
_compare('data_parallel_random_init')
if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size')
......@@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args):
_compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
......@@ -140,6 +141,32 @@ def read_metadata(tracker_filename):
return max_iter, release
def get_rng_state():
""" collect rng state across data parallel ranks """
args = get_args()
rng_state = {
'random_rng_state': random.getstate(),
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
rng_state_list = None
if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
rng_state,
group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]
return rng_state_list
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint."""
args = get_args()
......@@ -150,6 +177,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# collect rng state across data parallel ranks
rng_state = get_rng_state()
if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
......@@ -173,12 +203,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# RNG states.
if not args.no_save_rng:
state_dict['random_rng_state'] = random.getstate()
state_dict['np_rng_state'] = np.random.get_state()
state_dict['torch_rng_state'] = torch.get_rng_state()
state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
state_dict['rng_tracker_states'] \
= mpu.get_cuda_rng_tracker().get_states()
state_dict["rng_state"] = rng_state
# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
......@@ -381,15 +406,32 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
if 'rng_state' in state_dict:
# access rng_state for data parallel rank
if args.data_parallel_random_init:
rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatability
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
......
......@@ -16,8 +16,10 @@
"""Dataloaders."""
import torch
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron import get_args
from megatron import mpu
......@@ -39,11 +41,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataset,
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
data_parallel_size=mpu.get_data_parallel_world_size(),
data_sharding=args.data_sharding)
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
......@@ -103,16 +107,40 @@ class MegatronPretrainingSampler:
yield batch[start_idx:end_idx]
class RandomSeedDataset(Dataset):
def __init__(self, dataset):
args = get_args()
self.base_seed = args.seed
self.curr_seed = args.seed
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def set_epoch(self, epoch):
self.curr_seed = self.base_seed + epoch
def __getitem__(self, idx):
seed = idx + self.curr_seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
return self.dataset[idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, data_sharding):
# Keep a copy of input params for later use.
self.dataset = dataset
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.data_sharding = data_sharding
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
......@@ -136,16 +164,30 @@ class MegatronPretrainingRandomSampler:
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
if isinstance(self.dataset, RandomSeedDataset):
self.dataset.set_epoch(self.epoch)
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
if self.data_sharding:
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
else:
full_bucket_size = (self.total_samples // self.micro_batch_size) \
* self.micro_batch_size
full_bucket_offset = current_epoch_samples
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(full_bucket_size, generator=g).tolist()
idx_range_active = idx_range_total[full_bucket_offset:]
idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
batch = []
# Last batch if not complete will be dropped.
......
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
# added support for classes_fraction and data_per_class_fraction
from torchvision.datasets import VisionDataset
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
data_per_class_fraction: float,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Args:
directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
local_instances = []
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
local_instances.append(item)
instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])
return instances
class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/[...]/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/[...]/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
classes_fraction=1.0,
data_per_class_fraction=1.0,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
self.classes_fraction = classes_fraction
self.data_per_class_fraction = data_per_class_fraction
classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root,
class_to_idx,
self.data_per_class_fraction,
extensions,
is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = loader
self.extensions = extensions
self.total = len(samples)
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
@staticmethod
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
data_per_class_fraction: float,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
return make_dataset(directory,
class_to_idx,
data_per_class_fraction,
extensions=extensions,
is_valid_file=is_valid_file)
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
curr_index = index
for x in range(self.total):
try:
path, target = self.samples[curr_index]
sample = self.loader(path)
break
except Exception as e:
curr_index = np.random.randint(0, self.total)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self) -> int:
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path: str) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
classes_fraction=1.0,
data_per_class_fraction=1.0,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
classes_fraction=classes_fraction,
data_per_class_fraction=data_per_class_fraction,
is_valid_file=is_valid_file)
self.imgs = self.samples
......@@ -13,46 +13,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import torch
from torchvision import datasets, transforms
import torchvision.transforms as T
from torchvision import datasets
from megatron import get_args
from megatron.data.image_folder import ImageFolder
from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset
class ClassificationTransform():
def __init__(self, image_size, train=True):
args = get_args()
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
if train:
self.transform = T.Compose([
T.RandomResizedCrop(image_size),
T.RandomHorizontalFlip(),
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.Resize(image_size),
T.CenterCrop(image_size),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
def __call__(self, input):
output = self.transform(input)
return output
def build_train_valid_datasets(data_path, image_size=224):
args = get_args()
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
]
fp16_t = transforms.ConvertImageDtype(torch.half)
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
train_data_path = data_path[0]
train_data = ImageFolder(
root=train_data_path,
transform=train_transform,
classes_fraction=args.classes_fraction,
data_per_class_fraction=args.data_per_class_fraction
)
train_data = RandomSeedDataset(train_data)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
val_data_path = data_path[1]
val_data = ImageFolder(
root=val_data_path,
transform=val_transform
)
val_data = RandomSeedDataset(val_data)
return train_data, val_data
......@@ -78,6 +78,12 @@ def load(args):
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
# Softmax
sources=[srcpath / 'scaled_softmax.cpp',
srcpath / 'scaled_softmax_cuda.cu']
scaled_softmax_cuda = _cpp_extention_load_helper(
"scaled_softmax_cuda", sources, extra_cuda_flags)
# =================================
# Mixed precision fused layer norm.
# =================================
......
......@@ -90,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
......@@ -326,6 +437,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
......@@ -338,7 +541,7 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
......@@ -410,6 +613,10 @@ void dispatch_scaled_masked_softmax_forward(
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 12: // 4096
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
......@@ -427,7 +634,7 @@ void dispatch_scaled_masked_softmax_backward(
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
......@@ -498,6 +705,11 @@ void dispatch_scaled_masked_softmax_backward(
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
......
......@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
......
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(
torch::Tensor const& input,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_softmax::fwd,
"Self Multihead Attention scaled, softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_softmax::bwd,
"Self Multihead Attention scaled, softmax -- Backward.");
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_softmax_forward",
dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
......@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Random seeds for reproducibility.
if args.rank == 0:
print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed)
_set_random_seed(args.seed, args.data_parallel_random_init)
# Set pytorch JIT layer fusion options.
_set_jit_fusion_options()
......@@ -118,7 +118,7 @@ def _compile_dependencies():
args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
# Print a warning.
if not ((args.fp16 or args.bf16) and
......@@ -203,11 +203,14 @@ def _init_autoresume():
torch.distributed.barrier()
def _set_random_seed(seed_):
def _set_random_seed(seed_, data_parallel_random_init=False):
"""Set random seed for reproducability."""
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
# Ensure different data parallel ranks get different seeds
if data_parallel_random_init:
seed = seed + (10 * mpu.get_data_parallel_rank())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
......
......@@ -185,6 +185,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
buffer_.zero()
def broadcast_params(self):
for param in self.module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
def allreduce_gradients(self):
"""Reduce gradients across data parallel ranks."""
# If we have buffers, simply reduce the data in the buffer.
......
......@@ -81,6 +81,37 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
return input_grads, None, None
class ScaledSoftmax(torch.autograd.Function):
"""
Fused operation which performs following two operations in sequence
1. Scale the tensor.
2. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
......@@ -137,12 +168,11 @@ class FusedScaleMaskSoftmax(nn.Module):
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and 16 < sk <= 4096 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 2048:
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
......@@ -166,7 +196,10 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
if mask is not None:
return ScaledMaskedSoftmax.apply(input, mask, scale)
else:
return ScaledSoftmax.apply(input, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
......
......@@ -27,7 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -699,9 +698,32 @@ class ParallelTransformer(MegatronModule):
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
encoder_output = encoder_output.transpose(0, 1).contiguous()
# Forward pass.
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import torch
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False,
pre_process=True, post_process=True):
super(VitClassificationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.finetune = finetune
self.pre_process = pre_process
self.post_process = post_process
self.backbone = VitBackbone(
pre_process=self.pre_process,
post_process=self.post_process,
single_token_output=True
)
if self.post_process:
if not self.finetune:
self.head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.head = get_linear_layer(
self.hidden_size,
self.num_classes,
torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.backbone.set_input_tensor(input_tensor)
def forward(self, input):
hidden_states = self.backbone(input)
if self.post_process:
hidden_states = self.head(hidden_states)
return hidden_states
......@@ -18,16 +18,19 @@
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from .module import MegatronModule
from megatron.model.module import MegatronModule
CLASS_TOKEN_LENGTH = 8
class VitMlpHead(MegatronModule):
"""Pooler layer.
......@@ -44,19 +47,26 @@ class VitMlpHead(MegatronModule):
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
self.relu = torch.nn.ReLU()
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
torch.nn.init.constant_(self.dense_out.bias, -10)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
def forward(self, hidden_states):
# hidden_states: [b, 1, h]
# sequence_index: index of the token to pool.
hidden_state = hidden_states[:, sequence_index, :]
dense_in_result = self.dense_in(hidden_state)
dense_in_result = self.dense_in(hidden_states)
tanh_result = torch.tanh(dense_in_result)
dense_out_result = self.dense_out(tanh_result)
return dense_out_result
def isPerfectSquare(x):
if(x >= 0):
sr = math.sqrt(x)
return (int(sr) * int(sr) == x)
return False
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
......@@ -68,66 +78,77 @@ def twod_interpolate_position_embeddings_hook(
):
args = get_args()
num_patches_per_dim = args.img_dim // args.patch_dim
num_patches = num_patches_per_dim ** 2
seq_length = num_patches + 1
num_patches_per_dim_h = args.img_h // args.patch_dim
num_patches_per_dim_w = args.img_w // args.patch_dim
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
hidden_size = args.hidden_size
key = prefix + "weight"
# import pdb
# pdb.set_trace()
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
input_seq_len = input_param.shape[0]
assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH))
input_has_class_token = not isPerfectSquare(input_seq_len)
num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
num_tok_output = num_patches
output_has_class_token = args.class_token_present
# update input_param and load it to state_dict[key]
if input_has_class_token:
input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
else:
input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
input_param_grid = input_param
assert input_param.shape[1] == hidden_size
if input_param.shape[0] != seq_length:
# update input_param and load it to state_dict[key]
num_tok_input = input_param.shape[0] - 1
num_tok_new = seq_length - 1
input_param_tok, input_param_grid = (
input_param[:1, :],
input_param[1:, :],
)
if num_tok_input != num_tok_output:
gs_input = int(math.sqrt(num_tok_input))
gs_new = int(math.sqrt(num_tok_new))
gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = gs_new / gs_input
scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
input_param_grid = input_param_grid.reshape((-1, num_tok_output))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
assert (
input_param.shape[0] == seq_length
and input_param.shape[1] == hidden_size
)
state_dict[key] = input_param
input_param = input_param_grid
assert (
input_param.shape[0] == num_tok_output
and input_param.shape[1] == hidden_size
)
if output_has_class_token:
input_param = torch.cat((input_param_tok, input_param), dim=0)
state_dict[key] = input_param
class VitModel(MegatronModule):
class VitBackbone(MegatronModule):
"""Vision Transformer Model."""
def __init__(self,
num_classes,
finetune=False,
def __init__(self,
pre_process=True,
post_process=True):
super(VitModel, self).__init__(share_word_embeddings=False)
post_process=True,
class_token=True,
single_token_output=False):
super(VitBackbone, self).__init__(share_word_embeddings=False)
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
......@@ -142,25 +163,33 @@ class VitModel(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.class_token = class_token
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
self.img_dim = args.img_dim
self.finetune = finetune
assert self.img_dim % self.patch_dim == 0
self.num_patches_per_dim = self.img_dim // self.patch_dim
self.num_patches = self.num_patches_per_dim ** 2
self.seq_length = self.num_patches + 1
self.img_h = args.img_h
self.img_w = args.img_w
self.micro_batch_size = args.micro_batch_size
self.single_token_output = single_token_output
assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0
self.num_patches_per_dim_h = self.img_h // self.patch_dim
self.num_patches_per_dim_w = self.img_w // self.patch_dim
self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
self.input_tensor = None
self.position_ids = None
if self.pre_process:
# cls_token
self.cls_token = torch.nn.Parameter(
torch.randn(1, 1, self.hidden_size)
)
torch.nn.init.zeros_(self.cls_token)
if self.class_token:
self.cls_token = torch.nn.Parameter(
torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
)
torch.nn.init.zeros_(self.cls_token)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
......@@ -173,8 +202,8 @@ class VitModel(MegatronModule):
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
args.class_token_present = self.class_token
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
......@@ -183,21 +212,12 @@ class VitModel(MegatronModule):
# Transformer
self.transformer = ParallelTransformer(
self.init_method,
self.init_method,
self.scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process
post_process=self.post_process,
)
if self.post_process:
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.transformer.set_input_tensor(input_tensor)
......@@ -214,21 +234,22 @@ class VitModel(MegatronModule):
assert rearranged_input.dtype == torch.half
encoder_output = self.linear_encoder(rearranged_input)
cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
concatenated_tokens = encoder_output
if self.class_token:
cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
token_embeddings = concatenated_tokens + \
self.position_embeddings(self.position_ids)
self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
hidden_states = self.embedding_dropout(token_embeddings)
else:
hidden_states = input
hidden_states = self.transformer(hidden_states, None)
if self.post_process:
if not self.finetune:
hidden_states = self.mlp_head(hidden_states)
else:
hidden_states = self.class_head(hidden_states[:, 0, :])
if self.single_token_output:
hidden_states = hidden_states[:,0,:]
return hidden_states
......@@ -38,6 +38,7 @@ from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_data_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
......@@ -65,6 +66,9 @@ from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -484,6 +484,15 @@ def get_tensor_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
data_parallel_size = get_data_parallel_world_size()
num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
return global_rank % num_data_parallel_groups
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
......
......@@ -98,6 +98,84 @@ def gather_split_1d_tensor(tensor):
group=get_tensor_model_parallel_group())
return gathered
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -242,8 +320,9 @@ class CheckpointFunction(torch.autograd.Function):
# the chunk corresponding to the current rank.
if distribute_checkpointed_activations:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
new_buffer=True)
safely_set_viewless_tensor_data(
args[0],
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
# Store everything.
ctx.save_for_backward(*args)
......@@ -257,8 +336,9 @@ class CheckpointFunction(torch.autograd.Function):
"please use .backward() if possible")
inputs = ctx.saved_tensors
if ctx.distribute_checkpointed_activations:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
safely_set_viewless_tensor_data(
inputs[0],
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
......
......@@ -35,14 +35,14 @@ def _get_params_for_weight_decay_optimization(modules):
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
if p is not None and p.requires_grad])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
if p is not None and p.requires_grad and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
if p is not None and p.requires_grad and n == 'bias'])
return weight_decay_params, no_weight_decay_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment