Commit 456f1728 authored by Vijay Korthikanti

vision transformer model and vision classification task

parent f5eac3d1
@@ -40,6 +40,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
+    parser = _add_vit_args(parser)
     # Custom arguments.
     if extra_args_provider is not None:
@@ -271,6 +272,8 @@ def _add_regularization_args(parser):
     group.add_argument('--adam-eps', type=float, default=1e-08,
                        help='Term added to the denominator to improve'
                        'numerical stability')
+    group.add_argument('--sgd-momentum', type=float, default=0.9,
+                       help='Momentum factor for sgd')
     return parser
@@ -346,6 +349,9 @@ def _add_training_args(parser):
     group.add_argument('--no-bias-dropout-fusion', action='store_false',
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
+    group.add_argument('--optimizer', type=str, default='adam',
+                       choices=['adam', 'sgd'],
+                       help='Optimizer function')
     return parser
@@ -359,6 +365,8 @@ def _add_initialization_args(parser):
     group.add_argument('--init-method-std', type=float, default=0.02,
                        help='Standard deviation of the zero mean normal '
                        'distribution used for weight initialization.')
+    group.add_argument('--init-method-xavier-uniform', action='store_true',
+                       help='Enable Xavier uniform parameter initialization')
     return parser
@@ -607,3 +615,20 @@ def _add_realm_args(parser):
     group.add_argument('--indexer-log-interval', type=int, default=1000,
                        help='After how many batches should the indexer report progress')
     return parser
+
+
+def _add_vit_args(parser):
+    group = parser.add_argument_group(title="vit")
+
+    group.add_argument('--vit-load', type=str, default=None,
+                       help='Directory containing a VitModel checkpoint')
+    group.add_argument('--num-classes', type=int, default=1000,
+                       help='Number of classes in the vision classification task')
+    group.add_argument('--img-dim', type=int, default=224,
+                       help='Image size for vision classification task')
+    group.add_argument('--num-channels', type=int, default=3,
+                       help='Number of channels in input image data')
+    group.add_argument('--patch-dim', type=int, default=16,
+                       help='Patch dimension used in vit')
+
+    return parser
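For reference, a standalone argparse sketch (illustrative only, not part of the commit) of how the new vit argument group resolves its defaults when only some flags are overridden:

# Illustrative sketch: a plain argparse group mirroring _add_vit_args, showing
# the defaults (--num-classes 1000, --img-dim 224, --num-channels 3, --patch-dim 16).
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="vit")
group.add_argument('--vit-load', type=str, default=None)
group.add_argument('--num-classes', type=int, default=1000)
group.add_argument('--img-dim', type=int, default=224)
group.add_argument('--num-channels', type=int, default=3)
group.add_argument('--patch-dim', type=int, default=16)

args = parser.parse_args(['--img-dim', '384', '--patch-dim', '32'])
assert args.num_classes == 1000 and args.img_dim == 384 and args.patch_dim == 32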
@@ -59,6 +59,7 @@ def check_checkpoint_args(checkpoint_args):
     _compare('hidden_size')
     _compare('num_attention_heads')
     _compare('max_position_embeddings')
-    _compare('make_vocab_size_divisible_by')
-    _compare('padded_vocab_size')
-    _compare('tokenizer_type')
+    if args.vit_load is None:
+        _compare('make_vocab_size_divisible_by')
+        _compare('padded_vocab_size')
+        _compare('tokenizer_type')
@@ -159,7 +160,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     torch.distributed.barrier()


-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration."""
    args = get_args()
    load_dir = getattr(args, load_arg)
@@ -252,7 +253,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
-    model.load_state_dict(state_dict['model'])
+    model.load_state_dict(state_dict['model'], strict=strict)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
...
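The new strict flag is simply forwarded to load_state_dict; a minimal sketch of its effect with a plain torch module (illustrative, outside Megatron, presumably useful when loading a ViT backbone whose head differs from the checkpoint):

# Illustrative sketch: strict=False lets a checkpoint load even when key names
# do not match exactly; mismatches are reported instead of raising.
import torch

src = torch.nn.Linear(4, 4)
dst = torch.nn.Sequential(torch.nn.Linear(4, 4))  # keys "0.weight"/"0.bias" vs "weight"/"bias"
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys, result.unexpected_keys)  # non-empty lists, but no exception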
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
"""AutoAugment data augmentation policy for ImageNet.
Implements the fixed AutoAugment data augmentation policy for ImageNet
provided in Appendix A, Table 9 in reference [1]. Does not include any
of the search code.
Reference:
[1] https://arxiv.org/abs/1805.09501
Code adapted from:
https://github.com/DeepVoltaire/AutoAugment
"""
import random
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps
_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
class ImageNetPolicy:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
Usage example as a Pytorch Transform:
>>> transform=transforms.Compose([transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self.policies = [
SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
]
def __call__(self, img):
"""Define call method for ImageNetPolicy class."""
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
"""Define repr method for ImageNetPolicy class."""
return "ImageNetPolicy"
class SubPolicy:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def __init__(
self,
operation1,
probability1,
magnitude_idx1,
operation2,
probability2,
magnitude_idx2,
fillcolor,
):
"""Initialize a SubPolicy.
Args:
operation1 (str): Key specifying the first augmentation operation.
There are fourteen key values altogether (see supported_ops below
listing supported operations). probability1 (float): Probability
within [0., 1.] of applying the first augmentation operation.
magnitude_idx1 (int): Integer specifiying the strength of the first
operation as an index further used to derive the magnitude from a
range of possible values.
operation2 (str): Key specifying the second augmentation operation.
probability2 (float): Probability within [0., 1.] of applying the
second augmentation operation.
magnitude_idx2 (int): Integer specifiying the strength of the
second operation as an index further used to derive the magnitude
from a range of possible values.
fillcolor (tuple): RGB color components of the color to be used for
filling.
Returns:
"""
# List of supported operations for operation1 and operation2.
supported_ops = [
"shearX",
"shearY",
"translateX",
"translateY",
"rotate",
"color",
"posterize",
"solarize",
"contrast",
"sharpness",
"brightness",
"autocontrast",
"equalize",
"invert",
]
assert (operation1 in supported_ops) and (
operation2 in supported_ops
), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
assert (
0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert (
isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert (
isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels = _MAX_LEVEL + 1
ranges = {
"shearX": np.linspace(0, 0.3, num_levels),
"shearY": np.linspace(0, 0.3, num_levels),
"translateX": np.linspace(0, 150 / 331, num_levels),
"translateY": np.linspace(0, 150 / 331, num_levels),
"rotate": np.linspace(0, 30, num_levels),
"color": np.linspace(0.0, 0.9, num_levels),
"posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
np.int
),
"solarize": np.linspace(256, 0, num_levels), # range [0, 256]
"contrast": np.linspace(0.0, 0.9, num_levels),
"sharpness": np.linspace(0.0, 0.9, num_levels),
"brightness": np.linspace(0.0, 0.9, num_levels),
"autocontrast": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"equalize": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"invert": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
}
def rotate_with_fill(img, magnitude):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated = img.convert("RGBA").rotate(magnitude)
rotated_filled = Image.composite(
rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
)
return rotated_filled.convert(img.mode)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
magnitude * img.size[0] * random.choice([-1, 1]),
0,
1,
0,
),
fillcolor=fillcolor,
),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
0,
0,
1,
magnitude * img.size[1] * random.choice([-1, 1]),
),
fillcolor=fillcolor,
),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * random.choice([-1, 1])
),
"posterize": lambda img, magnitude: ImageOps.posterize(
img, magnitude
),
"solarize": lambda img, magnitude: ImageOps.solarize(
img, magnitude
),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self.probability1 = probability1
self.operation1 = func_dict[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self.probability2 = probability2
self.operation2 = func_dict[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if random.random() < self.probability1:
img = self.operation1(img, self.magnitude1)
# Randomly apply operation 2.
if random.random() < self.probability2:
img = self.operation2(img, self.magnitude2)
return img
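A short usage sketch for the new policy (illustrative only; assumes megatron is importable and torchvision is installed):

# Illustrative sketch: apply ImageNetPolicy to a dummy PIL image, both directly
# and inside a torchvision transform pipeline, as the class docstring suggests.
from PIL import Image
from torchvision import transforms
from megatron.data.autoaugment import ImageNetPolicy

img = Image.new("RGB", (256, 256), color=(128, 128, 128))
policy = ImageNetPolicy()
augmented = policy(img)                 # one randomly chosen sub-policy applied

pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    ImageNetPolicy(),
    transforms.ToTensor(),
])
tensor = pipeline(img)                  # float tensor of shape [3, 224, 224]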
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu


-def build_pretraining_data_loader(dataset, consumed_samples):
+def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False):
     """Buld dataloader given an input dataset."""

     if dataset is None:
@@ -35,7 +35,8 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         consumed_samples=consumed_samples,
         micro_batch_size=args.micro_batch_size,
         data_parallel_rank=mpu.get_data_parallel_rank(),
-        data_parallel_size=mpu.get_data_parallel_world_size())
+        data_parallel_size=mpu.get_data_parallel_world_size(),
+        random_sample=random_sample)

     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
@@ -46,41 +47,52 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, random_sample=False):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
-        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
-            data_parallel_size
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+        self.random_sample = random_sample

         # Sanity checks.
         assert self.total_samples > 0, \
             'no sample to consume: {}'.format(self.total_samples)
-        assert self.consumed_samples < self.total_samples, \
-            'no samples left to consume: {}, {}'.format(self.consumed_samples,
-                                                        self.total_samples)
+        #assert self.consumed_samples < self.total_samples, \
+        #    'no samples left to consume: {}, {}'.format(self.consumed_samples,
+        #                                                self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
         assert self.data_parallel_rank < data_parallel_size, \
             'data_parallel_rank should be smaller than data size: {}, ' \
             '{}'.format(self.data_parallel_rank, data_parallel_size)

     def __len__(self):
         return self.total_samples

     def __iter__(self):
+        self.epoch = self.consumed_samples // self.total_samples
+        current_epoch_samples = self.consumed_samples % self.total_samples
+        if self.random_sample:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            idx_range_total = \
+                torch.randperm(self.total_samples, generator=g).tolist()
+            idx_range = idx_range_total[current_epoch_samples:]
+        else:
+            idx_range = range(current_epoch_samples, self.total_samples)
+
         batch = []
         # Last batch if not complete will be dropped.
-        for idx in range(self.consumed_samples, self.total_samples):
+        for idx in idx_range:
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
+                self.consumed_samples += len(batch)
                 start_idx = self.data_parallel_rank * self.micro_batch_size
                 end_idx = start_idx + self.micro_batch_size
                 yield batch[start_idx:end_idx]
                 batch = []
+        self.consumed_samples += len(batch)
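A minimal standalone sketch of the epoch-seeded shuffling added to the sampler above (illustrative, plain PyTorch): seeding the generator with the epoch keeps the permutation reproducible across restarts, and slicing by current_epoch_samples skips the samples already consumed in that epoch.

# Illustrative sketch of the resume logic: same epoch -> same permutation,
# so a restart can skip exactly the samples it has already consumed.
import torch

total_samples, consumed_samples = 10, 7
epoch = consumed_samples // total_samples                  # 0
current_epoch_samples = consumed_samples % total_samples   # 7

g = torch.Generator()
g.manual_seed(epoch)
idx_range = torch.randperm(total_samples, generator=g).tolist()[current_epoch_samples:]
print(idx_range)  # the 3 remaining indices for this epoch, in a reproducible order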
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
]
process += [ImageNetPolicy(), transforms.ToTensor(), normalize]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
)
return train_data, val_data
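A usage sketch for the new dataset helper (illustrative; assumes the file lands at megatron/data/vit_dataset.py, and the ImageNet directory path is a hypothetical placeholder with the usual train/ and val/ ImageFolder layout):

# Illustrative sketch: build the ViT train/val datasets and wrap the train split
# in a plain DataLoader. "/data/imagenet" is a hypothetical path.
import torch
from megatron.data.vit_dataset import build_train_valid_datasets

train_data, val_data = build_train_valid_datasets(["/data/imagenet"], crop_size=224)
loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
images, labels = next(iter(loader))   # images: [32, 3, 224, 224]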
@@ -83,6 +83,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
                        defaults=args_defaults,
                        ignore_unknown_args=ignore_unknown_args)
     _build_num_microbatches_calculator(args)
-    _ = _build_tokenizer(args)
+    if args.vocab_file:
+        _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
@@ -131,7 +132,7 @@ def _set_tensorboard_writer(args):
                 'tensorboard writer')
     if hasattr(args, 'tensorboard_dir') and \
-       args.tensorboard_dir and args.rank == (args.world_size -1):
+       args.tensorboard_dir and args.rank == (args.world_size - 1):
         try:
             from torch.utils.tensorboard import SummaryWriter
             print('> setting tensorboard ...')
...
@@ -30,6 +30,7 @@ def import_layernorm(fp32_residual_connection):

 from .distributed import *
+from .vit_model import VitModel
 from .bert_model import (BertModel,
                          BertModelFirstStage,
                          BertModelIntermediateStage,
...
@@ -28,9 +28,6 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule

-def bert_attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores

 def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
@@ -144,7 +141,6 @@ class BertModelBase(MegatronModule):
                 args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
             init_method=init_method,
...
@@ -19,7 +19,7 @@ import torch

 from megatron import get_args, print_rank_last
 from megatron import mpu
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -37,7 +37,6 @@ class ClassificationBase(MegatronModule):
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             init_method=init_method,
...
@@ -15,62 +15,71 @@

 import torch

-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply upper triangular mask (typically used in gpt models).
     3. Perform softmax.
     """

     @staticmethod
     def forward(ctx, inputs, scale):
         import scaled_upper_triang_masked_softmax_cuda
         scale_t = torch.tensor([scale])

-        softmax_results = \
-            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
+        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
+            inputs, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = \
-            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
-                                                             softmax_results,
-                                                             scale_t[0])
+        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
         return input_grads, None


-class ScaledMaskedSoftmax(torch.autograd.Function) :
+class ScaledMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply the mask.
     3. Perform softmax.
     """

     @staticmethod
     def forward(ctx, inputs, mask, scale):
         import scaled_masked_softmax_cuda
         scale_t = torch.tensor([scale])

-        softmax_results = \
-            scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+        softmax_results = scaled_masked_softmax_cuda.forward(
+            inputs, mask, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads):
         import scaled_masked_softmax_cuda
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = \
-            scaled_masked_softmax_cuda.backward(output_grads,
-                                                softmax_results,
-                                                scale_t[0])
+        input_grads = scaled_masked_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
         return input_grads, None, None


 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
@@ -83,8 +92,16 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         scale: scaling factor used in input tensor scaling.
     """

-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(
+        self,
+        input_in_fp16,
+        upper_triang_mask_fusion,
+        general_mask_fusion,
+        mask_func,
+        softmax_in_fp32,
+        scale,
+    ):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.upper_triang_mask_fusion = upper_triang_mask_fusion
@@ -93,8 +110,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale

-        assert self.scale is None or softmax_in_fp32, \
-            'softmax should be in fp32 when scaled'
+        assert (
+            self.scale is None or softmax_in_fp32
+        ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, s, s]
@@ -102,9 +120,12 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert input.dim() == 4

         # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-            (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-            input.size()[2] == input.size()[3]:
+        if (
+            self.input_in_fp16
+            and data_size[-1] <= 2048
+            and (self.upper_triang_mask_fusion or self.general_mask_fusion)
+            and input.size()[2] == input.size()[3]
+        ):
             scale = self.scale if self.scale is not None else 1.0
             if self.upper_triang_mask_fusion:
                 input = input.view(-1, data_size[2], data_size[3])
@@ -118,7 +139,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             if self.scale is not None:
                 input = input * self.scale
-            mask_output = self.mask_func(input, mask)
+            mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)

             if self.input_in_fp16 and self.softmax_in_fp32:
...
@@ -27,11 +27,6 @@ from .utils import init_method_normal
 from .utils import scaled_init_method_normal


-def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(ltor_mask, -10000.0)
-    return attention_scores

 def post_language_model_processing(lm_output, labels, logit_weights,
                                    get_key_value, parallel_output,
                                    forward_method_parallel_output,
@@ -72,7 +67,6 @@ class GPT2ModelBase(MegatronModule):
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             init_method=init_method_normal(args.init_method_std),
...
@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return mpu.gather_from_tensor_model_parallel_region(logits_parallel)


-def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+def get_language_model(num_tokentypes, add_pooler,
                        init_method=None, scaled_init_method=None):
     """Build language model and return along with the key to save."""
     args = get_args()
@@ -54,7 +54,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

     # Language model.
-    args = [attention_mask_func, init_method, scaled_init_method]
+    args = [init_method, scaled_init_method]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
@@ -262,12 +262,6 @@ class TransformerLanguageModelBase(MegatronModule):

     Arguments:
         transformer_hparams: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-            masked-attention-scores = attention_mask_func(
-                unmaksed-attention-scores, attention-mask)
         vocab_size: vocabulary size
         max_sequence_length: maximum size of sequence. This
                              is used for positional embedding
@@ -277,7 +271,6 @@ class TransformerLanguageModelBase(MegatronModule):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
@@ -302,8 +295,7 @@ class TransformerLanguageModelBase(MegatronModule):

         # Transformer.
         self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method,
-            output_layer_init_method)
+            self.init_method, output_layer_init_method)
         self._transformer_key = 'transformer'

         # Pooler.
@@ -396,13 +388,11 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             num_tokentypes=num_tokentypes,
@@ -427,12 +417,10 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             num_tokentypes=num_tokentypes)
@@ -454,11 +442,9 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method)
@@ -478,12 +464,10 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             add_pooler=add_pooler)
...
@@ -19,7 +19,7 @@ import torch

 from megatron import get_args, print_rank_last
 from megatron import mpu
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -36,7 +36,6 @@ class MultipleChoiceBase(MegatronModule):
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             init_method=init_method,
...
@@ -10,7 +10,7 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import scaled_init_method_normal
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids


 def general_ict_model_provider(only_query_model=False, only_block_model=False):
@@ -156,7 +156,6 @@ class IREncoderBertModel(MegatronModule):
                 args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             init_method=init_method,
...
@@ -26,7 +26,7 @@ from megatron.checkpointing import get_checkpoint_version
 from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.utils import openai_gelu, erf_gelu
+from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -47,12 +47,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
     Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
         hyperparameters: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-            masked-attention-scores = attention_mask_func(
-                unmaksed-attention-scores, attention-mask)
 """

 class ParallelMLP(MegatronModule):
@@ -116,13 +110,11 @@ class ParallelSelfAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number):
+    def __init__(self, init_method, output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16

-        self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
@@ -155,7 +147,7 @@ class ParallelSelfAttention(MegatronModule):
             self.fp16,
             args.scaled_upper_triang_masked_softmax_fusion,
             args.scaled_masked_softmax_fusion,
-            self.attention_mask_func,
+            attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
@@ -173,7 +165,7 @@ class ParallelSelfAttention(MegatronModule):
             skip_bias_add=True)

     def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
-        input_shape = mixed_layer.size();
+        input_shape = mixed_layer.size()
         if num_splits_first:
             """[s, b, num_splits * np * hn]
             -->(view) [s, b, num_splits, np, hn]
@@ -246,7 +238,6 @@ class ParallelSelfAttention(MegatronModule):
             if get_key_value:
                 present = (key_layer, value_layer)

         # ===================================
         # Raw attention scores. [b, np, s, s]
         # ===================================
@@ -272,15 +263,15 @@ class ParallelSelfAttention(MegatronModule):
             device=torch.cuda.current_device())

         # Raw attention scores. [b * np, sq, sk]
-        matmul_result = torch.baddbmm(matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, sq, hn]
-            key_layer.transpose(0,1).transpose(1, 2),  #[b * np, hn, sk]
-            beta=0.0, alpha=(1.0/self.norm_factor))
+        matmul_result = torch.baddbmm(
+            matmul_result,
+            query_layer.transpose(0, 1),   # [b * np, sq, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+            beta=0.0, alpha=(1.0/self.norm_factor))

         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # ==================================================
         # Update attention mask for inference. [b, np, sq, sk]
         # ==================================================
@@ -298,7 +289,6 @@ class ParallelSelfAttention(MegatronModule):
                         :attention_scores.size(3),
                         :attention_scores.size(3)]

         # ===========================
         # Attention probs and dropout
         # ===========================
@@ -312,7 +302,6 @@ class ParallelSelfAttention(MegatronModule):
         with mpu.get_cuda_rng_tracker().fork():
             attention_probs = self.attention_dropout(attention_probs)

         # =========================
         # Context layer. [sq, b, hp]
         # =========================
@@ -335,7 +324,7 @@ class ParallelSelfAttention(MegatronModule):
                                        output_size[2], -1)

         # matmul: [b * np, sq, hn]
-        context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
+        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

         # change view [b, np, sq, hn]
         context_layer = context_layer.view(*output_size)
@@ -348,7 +337,6 @@ class ParallelSelfAttention(MegatronModule):
                                   (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

         # =================
         # Output. [sq, b, h]
         # =================
@@ -361,7 +349,7 @@ class ParallelSelfAttention(MegatronModule):
         return output, bias


-def bias_dropout_add(x, bias, residual, prob, training) :
+def bias_dropout_add(x, bias, residual, prob, training):
     # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
     out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
     out = residual + out
@@ -375,13 +363,13 @@ def get_bias_dropout_add(training):


 @torch.jit.script
-def bias_dropout_add_fused_train(x, bias, residual, prob) :
+def bias_dropout_add_fused_train(x, bias, residual, prob):
     # type: (Tensor, Tensor, Tensor, float) -> Tensor
     return bias_dropout_add(x, bias, residual, prob, True)


 @torch.jit.script
-def bias_dropout_add_fused_inference(x, bias, residual, prob) :
+def bias_dropout_add_fused_inference(x, bias, residual, prob):
     # type: (Tensor, Tensor, Tensor, float) -> Tensor
     return bias_dropout_add(x, bias, residual, prob, False)
@@ -393,8 +381,7 @@ class ParallelTransformerLayer(MegatronModule):
     output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number):
+    def __init__(self, init_method, output_layer_init_method, layer_number):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
@@ -410,7 +397,7 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon)

         # Self attention.
-        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
+        self.attention = ParallelSelfAttention(init_method,
                                                output_layer_init_method,
                                                layer_number)
         self.hidden_dropout = args.hidden_dropout
@@ -459,7 +446,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             bias_dropout_add_func = get_bias_dropout_add(self.training)

-        #re-enable torch grad to enable fused optimization.
+        # re-enable torch grad to enable fused optimization.
         with torch.enable_grad():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
@@ -479,7 +466,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

-        #re-enable torch grad to enable fused optimization.
+        # re-enable torch grad to enable fused optimization.
         with torch.enable_grad():
             output = bias_dropout_add_func(
                 mlp_output,
@@ -496,8 +483,7 @@ class ParallelTransformerLayer(MegatronModule):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, attention_mask_func,
-                 init_method, output_layer_init_method):
+    def __init__(self, init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
         args = get_args()
@@ -515,8 +501,7 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                attention_mask_func, init_method,
-                output_layer_init_method, layer_number)
+                init_method, output_layer_init_method, layer_number)

         offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...
@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
     return init_


+def attention_mask_func(attention_scores, attention_mask):
+    attention_scores.masked_fill_(attention_mask, -10000.0)
+    return attention_scores


 def get_linear_layer(rows, columns, init_method):
     """Simple linear layer with weight initialization."""
     layer = torch.nn.Linear(rows, columns)
...
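A small sketch of what the relocated attention_mask_func does (illustrative, plain PyTorch): positions where the mask is True are pushed to -10000 before the softmax, so they receive roughly zero attention probability.

# Illustrative sketch: masked positions receive -10000.0 in place,
# which drives their softmax probability to ~0.
import torch

scores = torch.zeros(1, 1, 2, 2)                          # [b, np, s, s]
mask = torch.tensor([[[[False, True], [False, False]]]])  # mask out one key position
scores.masked_fill_(mask, -10000.0)
probs = torch.softmax(scores, dim=-1)
print(probs[0, 0, 0])  # ~[1.0, 0.0]: the masked key is effectively ignored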
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""
import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from .module import MegatronModule
class VitMlpHead(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
torch.nn.init.constant_(self.dense_out.bias, -10)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
x = hidden_states[:, sequence_index, :]
x = self.dense_in(x)
x = torch.tanh(x)
x = self.dense_out(x)
return x
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
args = get_args()
num_patches_per_dim = args.img_dim // args.patch_dim
num_patches = num_patches_per_dim ** 2
seq_length = num_patches + 1
hidden_size = args.hidden_size
key = prefix + "weight"
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
assert input_param.shape[1] == hidden_size
if input_param.shape[0] != seq_length:
# update input_param and load it to state_dict[key]
num_tok_input = input_param.shape[0] - 1
num_tok_new = seq_length - 1
input_param_tok, input_param_grid = (
input_param[:1, :],
input_param[1:, :],
)
gs_input = int(math.sqrt(num_tok_input))
gs_new = int(math.sqrt(num_tok_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = gs_new / gs_input
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
assert (
input_param.shape[0] == seq_length
and input_param.shape[1] == hidden_size
)
state_dict[key] = input_param
class VitModel(MegatronModule):
"""Bert Language model."""
def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
self.img_dim = args.img_dim
self.finetune = finetune
assert self.img_dim % self.patch_dim == 0
self.num_patches_per_dim = self.img_dim // self.patch_dim
self.num_patches = self.num_patches_per_dim ** 2
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
# cls_token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
torch.nn.init.zeros_(self.cls_token)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method, self.scaled_init_method
)
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def forward(self, x):
x = einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
assert x.dtype == torch.half
x = self.linear_encoder(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.position_embeddings(self.position_ids)
x = self.embedding_dropout(x)
x = self.transformer(x, None)
if not self.finetune:
x = self.mlp_head(x)
else:
x = self.class_head(x[:, 0, :])
return x
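A quick check of the shape bookkeeping in VitModel for the default arguments (img_dim=224, patch_dim=16, num_channels=3; illustrative only):

# Illustrative sketch: with 224x224 images and 16x16 patches there are
# 14*14 = 196 patches, plus one class token -> sequence length 197,
# and each flattened patch holds 16*16*3 = 768 values before the linear encoder.
import torch
import einops

img_dim, patch_dim, num_channels = 224, 16, 3
num_patches = (img_dim // patch_dim) ** 2            # 196
seq_length = num_patches + 1                         # 197
flatten_dim = patch_dim * patch_dim * num_channels   # 768

x = torch.randn(2, num_channels, img_dim, img_dim)
patches = einops.rearrange(
    x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_dim, p2=patch_dim
)
print(patches.shape)  # torch.Size([2, 196, 768])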
@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_src_rank


-_MAX_DATA_DIM = 4
+_MAX_DATA_DIM = 5


 def _check_data_types(keys, data, target_dtype):
...
@@ -14,6 +14,7 @@
 # limitations under the License.

 from apex.optimizers import FusedAdam as Adam
+from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
 from megatron.model import import_layernorm
@@ -52,11 +53,18 @@ def get_megatron_optimizer(model):
     # Base optimizer.
     param_groups = _get_params_for_weight_decay_optimization(model)

-    optimizer = Adam(param_groups,
-                     lr=args.lr,
-                     weight_decay=args.weight_decay,
-                     betas=(args.adam_beta1, args.adam_beta2),
-                     eps=args.adam_eps)
+    if args.optimizer == 'adam':
+        optimizer = Adam(param_groups,
+                         lr=args.lr,
+                         weight_decay=args.weight_decay,
+                         betas=(args.adam_beta1, args.adam_beta2),
+                         eps=args.adam_eps)
+    else:
+        assert args.optimizer == 'sgd'
+        optimizer = SGD(param_groups,
+                        lr=args.lr,
+                        weight_decay=args.weight_decay,
+                        momentum=args.sgd_momentum)

     if args.fp16:
         # Constant loss scale.
...
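A standalone sketch of the new optimizer switch (illustrative; torch.optim.Adam/SGD stand in for the apex FusedAdam/FusedSGD used above, and the args values are hypothetical):

# Illustrative sketch of the adam/sgd branch in get_megatron_optimizer,
# using torch.optim in place of apex fused optimizers.
import torch

model = torch.nn.Linear(8, 8)
optimizer_name, lr, weight_decay, sgd_momentum = 'sgd', 0.1, 0.0001, 0.9

if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay,
                                 betas=(0.9, 0.999), eps=1e-8)
else:
    assert optimizer_name == 'sgd'
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                weight_decay=weight_decay,
                                momentum=sgd_momentum)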
@@ -57,8 +57,12 @@ def print_datetime(string):
     print_rank_0('[' + string + '] datetime: {} '.format(time_str))


-def pretrain(train_valid_test_dataset_provider, model_provider,
-             forward_step_func, extra_args_provider=None, args_defaults={}):
+def pretrain(train_valid_test_dataset_provider,
+             model_provider,
+             forward_step_func,
+             extra_args_provider=None,
+             args_defaults={},
+             random_sample=False):
     """Main training program.

     This function will run the followings in the order provided:
@@ -113,7 +117,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     timers('train/valid/test data iterators').start()
     train_data_iterator, valid_data_iterator, test_data_iterator \
         = build_train_valid_test_data_iterators(
-            train_valid_test_dataset_provider)
+            train_valid_test_dataset_provider,
+            random_sample)
     timers('train/valid/test data iterators').stop()
     print_datetime('after dataloaders are built')
@@ -944,8 +949,13 @@ def evaluate_and_print_results(prefix, forward_step_func,
     print_rank_last('-' * length)


+def cyclic_iterable(iterable):
+    while True:
+        for x in iterable:
+            yield x


 def build_train_valid_test_data_iterators(
-        build_train_valid_test_datasets_provider):
+        build_train_valid_test_datasets_provider, random_sample=False):
     """XXX"""
     args = get_args()
@@ -989,10 +999,10 @@ def build_train_valid_test_data_iterators(

         # Build dataloders.
         train_dataloader = build_pretraining_data_loader(
-            train_ds, args.consumed_train_samples)
+            train_ds, args.consumed_train_samples, random_sample)
         valid_dataloader = build_pretraining_data_loader(
-            valid_ds, args.consumed_valid_samples)
-        test_dataloader = build_pretraining_data_loader(test_ds, 0)
+            valid_ds, args.consumed_valid_samples, random_sample)
+        test_dataloader = build_pretraining_data_loader(test_ds, 0, random_sample)

     # Flags to know if we need to do training/validation/testing.
     do_train = train_dataloader is not None and args.train_iters > 0
@@ -1014,17 +1024,17 @@ def build_train_valid_test_data_iterators(

     # Build iterators.
     if train_dataloader is not None:
-        train_data_iterator = iter(train_dataloader)
+        train_data_iterator = iter(cyclic_iterable(train_dataloader))
     else:
         train_data_iterator = None

     if valid_dataloader is not None:
-        valid_data_iterator = iter(valid_dataloader)
+        valid_data_iterator = iter(cyclic_iterable(valid_dataloader))
     else:
         valid_data_iterator = None

     if test_dataloader is not None:
-        test_data_iterator = iter(test_dataloader)
+        test_data_iterator = iter(cyclic_iterable(test_dataloader))
     else:
         test_data_iterator = None
...
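The new cyclic_iterable wrapper never raises StopIteration, which lets training draw more iterations than a single pass over the dataloader provides; a tiny sketch (illustrative only):

# Illustrative sketch: cyclic_iterable restarts the underlying iterable forever,
# so iter(cyclic_iterable(loader)) can be drawn from indefinitely.
def cyclic_iterable(iterable):
    while True:
        for x in iterable:
            yield x

it = iter(cyclic_iterable([1, 2, 3]))
print([next(it) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]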