Commit 456f1728 authored by Vijay Korthikanti

vision transformer model and vision classification task

parent f5eac3d1
......@@ -40,6 +40,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
parser = _add_vit_args(parser)
# Custom arguments.
if extra_args_provider is not None:
......@@ -271,6 +272,8 @@ def _add_regularization_args(parser):
group.add_argument('--adam-eps', type=float, default=1e-08,
help='Term added to the denominator to improve'
'numerical stability')
group.add_argument('--sgd-momentum', type=float, default=0.9,
help='Momentum factor for sgd')
return parser
......@@ -346,6 +349,9 @@ def _add_training_args(parser):
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
return parser
......@@ -359,6 +365,8 @@ def _add_initialization_args(parser):
group.add_argument('--init-method-std', type=float, default=0.02,
help='Standard deviation of the zero mean normal '
'distribution used for weight initialization.')
group.add_argument('--init-method-xavier-uniform', action='store_true',
help='Enable Xavier uniform parameter initialization')
return parser
......@@ -607,3 +615,20 @@ def _add_realm_args(parser):
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer report progress')
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
group.add_argument('--vit-load', type=str, default=None,
help='Directory containing a VitModel checkpoint')
group.add_argument('--num-classes', type=int, default=1000,
help='Number of classes in the vision classification task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='Patch dimension (patch size) used in ViT')
return parser
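For reference, the new ViT flags can be exercised outside of Megatron's parse_args. A minimal standalone sketch (illustrative only; it mirrors the argument group above rather than calling Megatron code):

import argparse

# Standalone mirror of the "vit" argument group; defaults match the diff above.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="vit")
group.add_argument('--vit-load', type=str, default=None,
                   help='Directory containing a VitModel checkpoint')
group.add_argument('--num-classes', type=int, default=1000)
group.add_argument('--img-dim', type=int, default=224)
group.add_argument('--num-channels', type=int, default=3)
group.add_argument('--patch-dim', type=int, default=16)

# Example invocation; the defaults give a 14 x 14 = 196-patch sequence per image.
args = parser.parse_args(['--img-dim', '224', '--patch-dim', '16'])
assert args.img_dim % args.patch_dim == 0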
......@@ -59,6 +59,7 @@ def check_checkpoint_args(checkpoint_args):
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
if args.vit_load is None:
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
......@@ -159,7 +160,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration."""
args = get_args()
load_dir = getattr(args, load_arg)
......@@ -252,7 +253,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
print_rank_0('could not find arguments in the checkpoint ...')
# Model.
model.load_state_dict(state_dict['model'])
model.load_state_dict(state_dict['model'], strict=strict)
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
......
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
"""AutoAugment data augmentation policy for ImageNet.
Implements the fixed AutoAugment data augmentation policy for ImageNet
provided in Appendix A, Table 9 in reference [1]. Does not include any
of the search code.
Reference:
[1] https://arxiv.org/abs/1805.09501
Code adapted from:
https://github.com/DeepVoltaire/AutoAugment
"""
import random
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps
_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
class ImageNetPolicy:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
Usage example as a PyTorch transform:
>>> transform=transforms.Compose([transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self.policies = [
SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
]
def __call__(self, img):
"""Define call method for ImageNetPolicy class."""
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
"""Define repr method for ImageNetPolicy class."""
return "ImageNetPolicy"
class SubPolicy:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def __init__(
self,
operation1,
probability1,
magnitude_idx1,
operation2,
probability2,
magnitude_idx2,
fillcolor,
):
"""Initialize a SubPolicy.
Args:
operation1 (str): Key specifying the first augmentation operation.
There are fourteen key values altogether (see supported_ops below
listing supported operations). probability1 (float): Probability
within [0., 1.] of applying the first augmentation operation.
magnitude_idx1 (int): Integer specifiying the strength of the first
operation as an index further used to derive the magnitude from a
range of possible values.
operation2 (str): Key specifying the second augmentation operation.
probability2 (float): Probability within [0., 1.] of applying the
second augmentation operation.
magnitude_idx2 (int): Integer specifiying the strength of the
second operation as an index further used to derive the magnitude
from a range of possible values.
fillcolor (tuple): RGB color components of the color to be used for
filling.
Returns:
"""
# List of supported operations for operation1 and operation2.
supported_ops = [
"shearX",
"shearY",
"translateX",
"translateY",
"rotate",
"color",
"posterize",
"solarize",
"contrast",
"sharpness",
"brightness",
"autocontrast",
"equalize",
"invert",
]
assert (operation1 in supported_ops) and (
operation2 in supported_ops
), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
assert (
0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert (
isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert (
isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels = _MAX_LEVEL + 1
ranges = {
"shearX": np.linspace(0, 0.3, num_levels),
"shearY": np.linspace(0, 0.3, num_levels),
"translateX": np.linspace(0, 150 / 331, num_levels),
"translateY": np.linspace(0, 150 / 331, num_levels),
"rotate": np.linspace(0, 30, num_levels),
"color": np.linspace(0.0, 0.9, num_levels),
"posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
int
),
"solarize": np.linspace(256, 0, num_levels), # range [0, 256]
"contrast": np.linspace(0.0, 0.9, num_levels),
"sharpness": np.linspace(0.0, 0.9, num_levels),
"brightness": np.linspace(0.0, 0.9, num_levels),
"autocontrast": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"equalize": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"invert": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
}
def rotate_with_fill(img, magnitude):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated = img.convert("RGBA").rotate(magnitude)
rotated_filled = Image.composite(
rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
)
return rotated_filled.convert(img.mode)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
magnitude * img.size[0] * random.choice([-1, 1]),
0,
1,
0,
),
fillcolor=fillcolor,
),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
0,
0,
1,
magnitude * img.size[1] * random.choice([-1, 1]),
),
fillcolor=fillcolor,
),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * random.choice([-1, 1])
),
"posterize": lambda img, magnitude: ImageOps.posterize(
img, magnitude
),
"solarize": lambda img, magnitude: ImageOps.solarize(
img, magnitude
),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self.probability1 = probability1
self.operation1 = func_dict[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self.probability2 = probability2
self.operation2 = func_dict[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if random.random() < self.probability1:
img = self.operation1(img, self.magnitude1)
# Randomly apply operation 2.
if random.random() < self.probability2:
img = self.operation2(img, self.magnitude2)
return img
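Usage sketch for the policy above, following the class docstring; it assumes torchvision and Pillow are available, and the image path is hypothetical:

# ImageNetPolicy plugs in like any other PIL-level transform.
from PIL import Image
from torchvision import transforms
from megatron.data.autoaugment import ImageNetPolicy

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    ImageNetPolicy(),              # randomly picks one of the 25 sub-policies
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
tensor = transform_train(img)                   # [3, 224, 224] float tensor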
......@@ -22,7 +22,7 @@ from megatron import get_args
from megatron import mpu
def build_pretraining_data_loader(dataset, consumed_samples):
def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False):
"""Buld dataloader given an input dataset."""
if dataset is None:
......@@ -35,7 +35,8 @@ def build_pretraining_data_loader(dataset, consumed_samples):
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
data_parallel_size=mpu.get_data_parallel_world_size(),
random_sample=random_sample)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
......@@ -46,41 +47,52 @@ def build_pretraining_data_loader(dataset, consumed_samples):
class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
data_parallel_rank, data_parallel_size, random_sample=False):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.random_sample = random_sample
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
#assert self.consumed_samples < self.total_samples, \
# 'no samples left to consume: {}, {}'.format(self.consumed_samples,
# self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
self.epoch = self.consumed_samples // self.total_samples
current_epoch_samples = self.consumed_samples % self.total_samples
if self.random_sample:
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(self.total_samples, generator=g).tolist()
idx_range = idx_range_total[current_epoch_samples:]
else:
idx_range = range(current_epoch_samples, self.total_samples)
batch = []
# Last batch if not complete will be dropped.
for idx in range(self.consumed_samples, self.total_samples):
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
self.consumed_samples += len(batch)
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
yield batch[start_idx:end_idx]
batch = []
self.consumed_samples += len(batch)
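The random_sample path shuffles deterministically per epoch and skips the already-consumed prefix of the permutation on resume. A small self-contained sketch of just that logic (sample counts are illustrative):

import torch

total_samples, consumed_samples = 10, 4
epoch = consumed_samples // total_samples                  # 0
current_epoch_samples = consumed_samples % total_samples   # 4

g = torch.Generator()
g.manual_seed(epoch)   # same permutation every time this epoch is revisited
idx_range = torch.randperm(total_samples, generator=g).tolist()[current_epoch_samples:]
print(idx_range)       # the 6 not-yet-consumed indices of this epoch's permutation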
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
]
process += [ImageNetPolicy(), transforms.ToTensor(), normalize]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
)
return train_data, val_data
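Usage sketch for the dataset builder above. It expects an ImageNet-style layout (<root>/train/<class>/... and <root>/val/<class>/...); the root path below is hypothetical and the module path is assumed to be megatron.data.vit_dataset:

from torch.utils.data import DataLoader
from megatron.data.vit_dataset import build_train_valid_datasets  # assumed module path

# data_path is a list, matching the data_path[0] indexing in the builder.
train_data, val_data = build_train_valid_datasets(["/data/imagenet"],
                                                  crop_size=224,
                                                  color_jitter=True)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True,
                          num_workers=4, pin_memory=True)
images, labels = next(iter(train_loader))   # images: [32, 3, 224, 224]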
......@@ -83,6 +83,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
if args.vocab_file:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
......@@ -131,7 +132,7 @@ def _set_tensorboard_writer(args):
'tensorboard writer')
if hasattr(args, 'tensorboard_dir') and \
args.tensorboard_dir and args.rank == (args.world_size -1):
args.tensorboard_dir and args.rank == (args.world_size - 1):
try:
from torch.utils.tensorboard import SummaryWriter
print('> setting tensorboard ...')
......
......@@ -30,6 +30,7 @@ def import_layernorm(fp32_residual_connection):
from .distributed import *
from .vit_model import VitModel
from .bert_model import (BertModel,
BertModelFirstStage,
BertModelIntermediateStage,
......
......@@ -28,9 +28,6 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
......@@ -144,7 +141,6 @@ class BertModelBase(MegatronModule):
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
init_method=init_method,
......
......@@ -19,7 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
......@@ -37,7 +37,6 @@ class ClassificationBase(MegatronModule):
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
......
......@@ -15,62 +15,71 @@
import torch
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = \
scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = \
scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
softmax_results,
scale_t[0])
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function) :
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = \
scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
softmax_results = scaled_masked_softmax_cuda.forward(
inputs, mask, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = \
scaled_masked_softmax_cuda.backward(output_grads,
softmax_results,
scale_t[0])
input_grads = scaled_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
......@@ -83,8 +92,16 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
scale: scaling factor used in input tensor scaling.
"""
def __init__(self, input_in_fp16, upper_triang_mask_fusion,
general_mask_fusion, mask_func, softmax_in_fp32, scale):
def __init__(
self,
input_in_fp16,
upper_triang_mask_fusion,
general_mask_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion
......@@ -93,8 +110,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, \
'softmax should be in fp32 when scaled'
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, s, s]
......@@ -102,9 +120,12 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert input.dim() == 4
# invoke custom kernel
if self.input_in_fp16 and data_size[-1] <= 2048 and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \
input.size()[2] == input.size()[3]:
if (
self.input_in_fp16
and data_size[-1] <= 2048
and (self.upper_triang_mask_fusion or self.general_mask_fusion)
and input.size()[2] == input.size()[3]
):
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
input = input.view(-1, data_size[2], data_size[3])
......@@ -118,7 +139,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask)
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32:
......
......@@ -27,11 +27,6 @@ from .utils import init_method_normal
from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
def post_language_model_processing(lm_output, labels, logit_weights,
get_key_value, parallel_output,
forward_method_parallel_output,
......@@ -72,7 +67,6 @@ class GPT2ModelBase(MegatronModule):
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=gpt2_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
init_method=init_method_normal(args.init_method_std),
......
......@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
def get_language_model(num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -54,7 +54,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
# Language model.
args = [attention_mask_func, init_method, scaled_init_method]
args = [init_method, scaled_init_method]
kwargs = {}
cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
......@@ -262,12 +262,6 @@ class TransformerLanguageModelBase(MegatronModule):
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
......@@ -277,7 +271,6 @@ class TransformerLanguageModelBase(MegatronModule):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
......@@ -302,8 +295,7 @@ class TransformerLanguageModelBase(MegatronModule):
# Transformer.
self.transformer = ParallelTransformer(
attention_mask_func, self.init_method,
output_layer_init_method)
self.init_method, output_layer_init_method)
self._transformer_key = 'transformer'
# Pooler.
......@@ -396,13 +388,11 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes,
......@@ -427,12 +417,10 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes)
......@@ -454,11 +442,9 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
......@@ -478,12 +464,10 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=add_pooler)
......
......@@ -19,7 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
......@@ -36,7 +36,6 @@ class MultipleChoiceBase(MegatronModule):
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
......
......@@ -10,7 +10,7 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False):
......@@ -156,7 +156,6 @@ class IREncoderBertModel(MegatronModule):
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
......
......@@ -26,7 +26,7 @@ from megatron.checkpointing import get_checkpoint_version
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
......@@ -47,12 +47,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
......@@ -116,13 +110,11 @@ class ParallelSelfAttention(MegatronModule):
and returns output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
def __init__(self, init_method, output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
......@@ -155,7 +147,7 @@ class ParallelSelfAttention(MegatronModule):
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
args.scaled_masked_softmax_fusion,
self.attention_mask_func,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
......@@ -173,7 +165,7 @@ class ParallelSelfAttention(MegatronModule):
skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size();
input_shape = mixed_layer.size()
if num_splits_first:
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
......@@ -246,7 +238,6 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
......@@ -272,15 +263,15 @@ class ParallelSelfAttention(MegatronModule):
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result,
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
......@@ -298,7 +289,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
......@@ -312,7 +302,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
......@@ -335,7 +324,7 @@ class ParallelSelfAttention(MegatronModule):
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
......@@ -348,7 +337,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
......@@ -361,7 +349,7 @@ class ParallelSelfAttention(MegatronModule):
return output, bias
def bias_dropout_add(x, bias, residual, prob, training) :
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
......@@ -375,13 +363,13 @@ def get_bias_dropout_add(training):
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
def bias_dropout_add_fused_train(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
def bias_dropout_add_fused_inference(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False)
......@@ -393,8 +381,7 @@ class ParallelTransformerLayer(MegatronModule):
output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
def __init__(self, init_method, output_layer_init_method, layer_number):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
......@@ -410,7 +397,7 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
self.attention = ParallelSelfAttention(init_method,
output_layer_init_method,
layer_number)
self.hidden_dropout = args.hidden_dropout
......@@ -459,7 +446,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
......@@ -479,7 +466,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
......@@ -496,8 +483,7 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, attention_mask_func,
init_method, output_layer_init_method):
def __init__(self, init_method, output_layer_init_method):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -515,8 +501,7 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number)
init_method, output_layer_init_method, layer_number)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......
......@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""
import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from .module import MegatronModule
class VitMlpHead(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
torch.nn.init.constant_(self.dense_out.bias, -10)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
x = hidden_states[:, sequence_index, :]
x = self.dense_in(x)
x = torch.tanh(x)
x = self.dense_out(x)
return x
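For shape intuition, an equivalent head written against plain torch.nn (MegatronModule swapped for torch.nn.Module; sizes are illustrative):

import torch

class _MlpHeadSketch(torch.nn.Module):
    # Same structure as VitMlpHead, minus the Megatron base class.
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, hidden_states, sequence_index=0):
        x = hidden_states[:, sequence_index, :]   # pool one token: [b, h]
        return self.dense_out(torch.tanh(self.dense_in(x)))

head = _MlpHeadSketch(hidden_size=768, num_classes=1000)
out = head(torch.randn(2, 197, 768))   # [b=2, s=197, h=768] -> [2, 1000]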
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
args = get_args()
num_patches_per_dim = args.img_dim // args.patch_dim
num_patches = num_patches_per_dim ** 2
seq_length = num_patches + 1
hidden_size = args.hidden_size
key = prefix + "weight"
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
assert input_param.shape[1] == hidden_size
if input_param.shape[0] != seq_length:
# update input_param and load it to state_dict[key]
num_tok_input = input_param.shape[0] - 1
num_tok_new = seq_length - 1
input_param_tok, input_param_grid = (
input_param[:1, :],
input_param[1:, :],
)
gs_input = int(math.sqrt(num_tok_input))
gs_new = int(math.sqrt(num_tok_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = gs_new / gs_input
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
assert (
input_param.shape[0] == seq_length
and input_param.shape[1] == hidden_size
)
state_dict[key] = input_param
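The hook above resizes a pretrained position-embedding grid when the checkpoint's patch grid differs from the current model's. A standalone sketch of the interpolation step (grid sizes are illustrative; size= is passed to F.interpolate here for an exact target grid, whereas the hook derives a scale factor):

import torch
import torch.nn.functional as F

hidden_size = 768
gs_input, gs_new = 14, 16                                    # e.g. 224/16 -> 256/16
table = torch.randn(1 + gs_input * gs_input, hidden_size)    # [CLS] + grid embeddings

# Split off the [CLS] row, reshape the grid to [1, h, gs, gs], resize, and re-flatten.
tok, grid = table[:1, :], table[1:, :]
grid = grid.transpose(0, 1).reshape(1, hidden_size, gs_input, gs_input)
grid = F.interpolate(grid, size=(gs_new, gs_new), mode="bilinear", align_corners=False)
grid = grid.reshape(hidden_size, gs_new * gs_new).transpose(0, 1)

new_table = torch.cat((tok, grid), dim=0)
assert new_table.shape == (1 + gs_new * gs_new, hidden_size)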
class VitModel(MegatronModule):
"""Bert Language model."""
def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
self.img_dim = args.img_dim
self.finetune = finetune
assert self.img_dim % self.patch_dim == 0
self.num_patches_per_dim = self.img_dim // self.patch_dim
self.num_patches = self.num_patches_per_dim ** 2
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
# cls_token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
torch.nn.init.zeros_(self.cls_token)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method, self.scaled_init_method
)
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def forward(self, x):
x = einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
assert x.dtype == torch.half
x = self.linear_encoder(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.position_embeddings(self.position_ids)
x = self.embedding_dropout(x)
x = self.transformer(x, None)
if not self.finetune:
x = self.mlp_head(x)
else:
x = self.class_head(x[:, 0, :])
return x
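Shape sketch of the patchification step performed by einops.rearrange in forward (batch size is illustrative):

import einops
import torch

# A [b, c, H, W] image becomes [b, num_patches, patch_dim*patch_dim*c],
# which the linear encoder then projects to hidden_size.
b, c, img_dim, patch_dim = 2, 3, 224, 16
x = torch.randn(b, c, img_dim, img_dim)
patches = einops.rearrange(
    x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_dim, p2=patch_dim
)
print(patches.shape)   # torch.Size([2, 196, 768]): 196 patches of 16*16*3 values each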
......@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import import_layernorm
......@@ -52,11 +53,18 @@ def get_megatron_optimizer(model):
# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps)
else:
assert args.optimizer == 'sgd'
optimizer = SGD(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
momentum=args.sgd_momentum)
if args.fp16:
# Constant loss scale.
......
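The selection above is driven by the new --optimizer and --sgd-momentum flags. A standalone sketch of the same branch using torch.optim as a stand-in for the apex fused optimizers (values are illustrative):

import torch

model = torch.nn.Linear(8, 2)
optimizer_name, lr, weight_decay, sgd_momentum = 'sgd', 0.1, 0.0, 0.9

if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay,
                                 betas=(0.9, 0.999), eps=1e-8)
else:
    assert optimizer_name == 'sgd'
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                weight_decay=weight_decay,
                                momentum=sgd_momentum)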
......@@ -57,8 +57,12 @@ def print_datetime(string):
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={}):
def pretrain(train_valid_test_dataset_provider,
model_provider,
forward_step_func,
extra_args_provider=None,
args_defaults={},
random_sample=False):
"""Main training program.
This function will run the followings in the order provided:
......@@ -113,7 +117,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
timers('train/valid/test data iterators').start()
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
train_valid_test_dataset_provider,
random_sample)
timers('train/valid/test data iterators').stop()
print_datetime('after dataloaders are built')
......@@ -944,8 +949,13 @@ def evaluate_and_print_results(prefix, forward_step_func,
print_rank_last('-' * length)
def cyclic_iterable(iterable):
while True:
for x in iterable:
yield x
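cyclic_iterable restarts an exhausted dataloader so iteration-based training can run past a single epoch. A tiny self-contained demonstration (the helper is copied here so the snippet runs on its own):

# The wrapped iterator never raises StopIteration.
def cyclic_iterable(iterable):
    while True:
        for x in iterable:
            yield x

it = iter(cyclic_iterable([0, 1, 2]))
print([next(it) for _ in range(7)])   # [0, 1, 2, 0, 1, 2, 0]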
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
build_train_valid_test_datasets_provider, random_sample=False):
"""XXX"""
args = get_args()
......@@ -989,10 +999,10 @@ def build_train_valid_test_data_iterators(
# Build dataloders.
train_dataloader = build_pretraining_data_loader(
train_ds, args.consumed_train_samples)
train_ds, args.consumed_train_samples, random_sample)
valid_dataloader = build_pretraining_data_loader(
valid_ds, args.consumed_valid_samples)
test_dataloader = build_pretraining_data_loader(test_ds, 0)
valid_ds, args.consumed_valid_samples, random_sample)
test_dataloader = build_pretraining_data_loader(test_ds, 0, random_sample)
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and args.train_iters > 0
......@@ -1014,17 +1024,17 @@ def build_train_valid_test_data_iterators(
# Build iterators.
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader)
train_data_iterator = iter(cyclic_iterable(train_dataloader))
else:
train_data_iterator = None
if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader)
valid_data_iterator = iter(cyclic_iterable(valid_dataloader))
else:
valid_data_iterator = None
if test_dataloader is not None:
test_data_iterator = iter(test_dataloader)
test_data_iterator = iter(cyclic_iterable(test_dataloader))
else:
test_data_iterator = None
......