Commit 539e39b6 authored by Jared Casper

Merge branch 'vision-merge' into 'main'

Phase1 merge: vit optimizations + dataset enhancements + scaled_softmax kernel

See merge request ADLR/megatron-lm!378
parents 9a8b89ac 798b6a64
@@ -518,6 +518,9 @@ def _add_initialization_args(parser):
     group.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy, '
                        'pytorch, and cuda.')
+    group.add_argument('--data-parallel-random-init', action='store_true',
+                       help='Enable random initialization of params '
+                       'across data parallel ranks')
     group.add_argument('--init-method-std', type=float, default=0.02,
                        help='Standard deviation of the zero mean normal '
                        'distribution used for weight initialization.')
@@ -835,11 +838,20 @@ def _add_vit_args(parser):
     group.add_argument('--num-classes', type=int, default=1000,
                        help='num of classes in vision classification task')
-    group.add_argument('--img-dim', type=int, default=224,
-                       help='Image size for vision classification task')
+    group.add_argument('--img-h', type=int, default=224,
+                       help='Image height for vision classification task')
+    group.add_argument('--img-w', type=int, default=224,
+                       help='Image width for vision classification task')
     group.add_argument('--num-channels', type=int, default=3,
                        help='Number of channels in input image data')
     group.add_argument('--patch-dim', type=int, default=16,
                        help='patch dimension used in vit')
+    group.add_argument('--classes-fraction', type=float, default=1.0,
+                       help='training with fraction of classes.')
+    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
+                       help='training with fraction of data per class.')
+    group.add_argument('--no-data-sharding', action='store_false',
+                       help='Disable data sharding.',
+                       dest='data_sharding')
     return parser
@@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args):
     _compare('make_vocab_size_divisible_by')
     _compare('padded_vocab_size')
     _compare('tokenizer_type')
+    if args.data_parallel_random_init:
+        _compare('data_parallel_random_init')
     if get_checkpoint_version() < 3.0:
         _compare('tensor_model_parallel_size',
                  old_arg_name='model_parallel_size')
@@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args):
         _compare('tensor_model_parallel_size')
         _compare('pipeline_model_parallel_size')

 def ensure_directory_exists(filename):
     """Build filename's path if it does not already exist."""
     dirname = os.path.dirname(filename)
@@ -140,6 +141,32 @@ def read_metadata(tracker_filename):
     return max_iter, release

+def get_rng_state():
+    """Collect rng state across data parallel ranks."""
+    args = get_args()
+    rng_state = {
+        'random_rng_state': random.getstate(),
+        'np_rng_state': np.random.get_state(),
+        'torch_rng_state': torch.get_rng_state(),
+        'cuda_rng_state': torch.cuda.get_rng_state(),
+        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+
+    rng_state_list = None
+    if torch.distributed.is_initialized() and \
+            mpu.get_data_parallel_world_size() > 1 and \
+            args.data_parallel_random_init:
+        rng_state_list = \
+            [None for i in range(mpu.get_data_parallel_world_size())]
+        torch.distributed.all_gather_object(
+            rng_state_list,
+            rng_state,
+            group=mpu.get_data_parallel_group())
+    else:
+        rng_state_list = [rng_state]
+
+    return rng_state_list

 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
@@ -150,6 +177,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

+    # Collect rng state across data parallel ranks.
+    rng_state = get_rng_state()
+
     if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
@@ -173,12 +203,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

         # RNG states.
         if not args.no_save_rng:
-            state_dict['random_rng_state'] = random.getstate()
-            state_dict['np_rng_state'] = np.random.get_state()
-            state_dict['torch_rng_state'] = torch.get_rng_state()
-            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-            state_dict['rng_tracker_states'] \
-                = mpu.get_cuda_rng_tracker().get_states()
+            state_dict["rng_state"] = rng_state

         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
@@ -381,15 +406,32 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            random.setstate(state_dict['random_rng_state'])
-            np.random.set_state(state_dict['np_rng_state'])
-            torch.set_rng_state(state_dict['torch_rng_state'])
-            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
-            # Check for empty states array
-            if not state_dict['rng_tracker_states']:
-                raise KeyError
-            mpu.get_cuda_rng_tracker().set_states(
-                state_dict['rng_tracker_states'])
+            if 'rng_state' in state_dict:
+                # access rng_state for data parallel rank
+                if args.data_parallel_random_init:
+                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                else:
+                    rng_state = state_dict['rng_state'][0]
+                random.setstate(rng_state['random_rng_state'])
+                np.random.set_state(rng_state['np_rng_state'])
+                torch.set_rng_state(rng_state['torch_rng_state'])
+                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+                # Check for empty states array
+                if not rng_state['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    rng_state['rng_tracker_states'])
+            else:  # backward compatibility
+                random.setstate(state_dict['random_rng_state'])
+                np.random.set_state(state_dict['np_rng_state'])
+                torch.set_rng_state(state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                # Check for empty states array
+                if not state_dict['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
...
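The new checkpoint layout stores a list of per-rank RNG snapshots rather than one flat set of keys. The gather relies on torch.distributed.all_gather_object, which pickles arbitrary Python objects across the group. A minimal standalone sketch of the same round-trip (hypothetical helpers, assuming an initialized process group; only the random and torch states are shown):

import random
import torch
import torch.distributed as dist

def gather_rng_states(group=None):
    # Every rank contributes its snapshot; all ranks receive the full
    # list, indexed by rank within the group.
    state = {'random_rng_state': random.getstate(),
             'torch_rng_state': torch.get_rng_state()}
    states = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(states, state, group=group)
    return states

def restore_rng_state(states, rank):
    # On load, each rank restores only its own slot (slot 0 when all
    # ranks were seeded identically and only one snapshot was saved).
    random.setstate(states[rank]['random_rng_state'])
    torch.set_rng_state(states[rank]['torch_rng_state'])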
@@ -16,8 +16,10 @@
 """Dataloaders."""

-import torch
 import random
+import torch
+import numpy as np
+from torch.utils.data import Dataset

 from megatron import get_args
 from megatron import mpu
@@ -39,11 +41,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         data_parallel_size=mpu.get_data_parallel_world_size())
     elif args.dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
+            dataset,
             total_samples=len(dataset),
             consumed_samples=consumed_samples,
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+            data_sharding=args.data_sharding)
     else:
         raise Exception('{} dataloader type is not supported.'.format(
             args.dataloader_type))
@@ -103,16 +107,40 @@ class MegatronPretrainingSampler:
             yield batch[start_idx:end_idx]

+class RandomSeedDataset(Dataset):
+
+    def __init__(self, dataset):
+        args = get_args()
+        self.base_seed = args.seed
+        self.curr_seed = args.seed
+        self.dataset = dataset
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def set_epoch(self, epoch):
+        self.curr_seed = self.base_seed + epoch
+
+    def __getitem__(self, idx):
+        seed = idx + self.curr_seed
+        torch.manual_seed(seed)
+        random.seed(seed)
+        np.random.seed(seed)
+        return self.dataset[idx]
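RandomSeedDataset makes every random augmentation a pure function of (epoch, index): __getitem__ reseeds the python, numpy, and torch RNGs with curr_seed + idx before delegating to the wrapped dataset. A small illustrative check (the toy dataset is hypothetical; constructing RandomSeedDataset requires Megatron's global args for the base seed):

import torch
from torch.utils.data import Dataset

class NoisyDataset(Dataset):
    # Toy dataset whose "augmentation" is a fresh random draw per access.
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return torch.rand(1)

ds = RandomSeedDataset(NoisyDataset())
ds.set_epoch(0)
a = ds[3]
ds.set_epoch(0)
b = ds[3]
assert torch.equal(a, b)  # same (epoch, idx) -> same augmentation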
 class MegatronPretrainingRandomSampler:

-    def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size, data_sharding):
         # Keep a copy of input params for later use.
+        self.dataset = dataset
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
         self.data_parallel_size = data_parallel_size
+        self.data_sharding = data_sharding
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
         self.last_batch_size = \
@@ -136,16 +164,30 @@ class MegatronPretrainingRandomSampler:
         current_epoch_samples = self.consumed_samples % active_total_samples
         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

+        if isinstance(self.dataset, RandomSeedDataset):
+            self.dataset.set_epoch(self.epoch)
+
         # data sharding and random sampling
-        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
-        bucket_offset = current_epoch_samples // self.data_parallel_size
-        start_idx = self.data_parallel_rank * bucket_size
-
-        g = torch.Generator()
-        g.manual_seed(self.epoch)
-        random_idx = torch.randperm(bucket_size, generator=g).tolist()
-        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+        if self.data_sharding:
+            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                          * self.micro_batch_size
+            bucket_offset = current_epoch_samples // self.data_parallel_size
+            start_idx = self.data_parallel_rank * bucket_size
+
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            random_idx = torch.randperm(bucket_size, generator=g).tolist()
+            idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+        else:
+            full_bucket_size = (self.total_samples // self.micro_batch_size) \
+                               * self.micro_batch_size
+            full_bucket_offset = current_epoch_samples
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            idx_range_total = \
+                torch.randperm(full_bucket_size, generator=g).tolist()
+            idx_range_active = idx_range_total[full_bucket_offset:]
+            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

         batch = []
         # Last batch if not complete will be dropped.
...
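The two branches partition an epoch differently: with sharding, each rank owns a contiguous bucket of the dataset and shuffles only within it; without sharding, all ranks draw from one global permutation and take strided slices, so every rank can touch every sample. A toy illustration of the index assignment, outside Megatron (8 samples, micro batch 1, 2 data-parallel ranks):

import torch

total_samples, micro, dp, epoch = 8, 1, 2, 0

# Sharded: rank r shuffles only its own bucket [r*bucket, (r+1)*bucket).
bucket = (total_samples // (micro * dp)) * micro  # 4
g = torch.Generator()
g.manual_seed(epoch)
perm = torch.randperm(bucket, generator=g).tolist()
sharded = {r: [r * bucket + x for x in perm] for r in range(dp)}

# Unsharded: one global permutation, sliced with stride dp per rank.
g = torch.Generator()
g.manual_seed(epoch)
full = torch.randperm(total_samples, generator=g).tolist()
unsharded = {r: full[r::dp] for r in range(dp)}

# sharded: rank 0 sees only indices 0..3, rank 1 only 4..7;
# unsharded: both ranks sample from the full 0..7 range.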
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
# added support for classes_fraction and data_per_class_fraction
from torchvision.datasets import VisionDataset
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
data_per_class_fraction: float,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Args:
directory (str): root dataset directory
        class_to_idx (Dict[str, int]): dictionary mapping class name to class index
        data_per_class_fraction (float): fraction of the files kept for each class
        extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
            (used to check for corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
local_instances = []
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
local_instances.append(item)
instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])
return instances
class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/[...]/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/[...]/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        classes_fraction (float, optional): fraction of the classes to keep.
        data_per_class_fraction (float, optional): fraction of the files kept per class.
is_valid_file (callable, optional): A function that takes path of a file
            and checks if the file is a valid file (used to check for corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
classes_fraction=1.0,
data_per_class_fraction=1.0,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
self.classes_fraction = classes_fraction
self.data_per_class_fraction = data_per_class_fraction
classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root,
class_to_idx,
self.data_per_class_fraction,
extensions,
is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = loader
self.extensions = extensions
self.total = len(samples)
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
@staticmethod
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
data_per_class_fraction: float,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
return make_dataset(directory,
class_to_idx,
data_per_class_fraction,
extensions=extensions,
is_valid_file=is_valid_file)
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
curr_index = index
for x in range(self.total):
try:
path, target = self.samples[curr_index]
sample = self.loader(path)
break
            except Exception:
                # Unreadable/corrupt sample: retry with a random index.
                curr_index = np.random.randint(0, self.total)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self) -> int:
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path: str) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
            and checks if the file is a valid file (used to check for corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
classes_fraction=1.0,
data_per_class_fraction=1.0,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
classes_fraction=classes_fraction,
data_per_class_fraction=data_per_class_fraction,
is_valid_file=is_valid_file)
self.imgs = self.samples
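The two fraction knobs subsample deterministically: classes_fraction keeps a prefix of the scanned class-folder list (before sorting), and data_per_class_fraction keeps a prefix of each kept class's sorted file list. A hypothetical usage sketch (the path is a placeholder):

from megatron.data.image_folder import ImageFolder

# Keep half of the classes and 10% of the images inside each kept class.
dataset = ImageFolder(
    root='/data/imagenet/train',  # placeholder path
    classes_fraction=0.5,
    data_per_class_fraction=0.1,
)
print(len(dataset.classes), len(dataset))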
@@ -13,46 +13,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import os
+import random
+import numpy as np
 import torch
-from torchvision import datasets, transforms
+import torchvision.transforms as T
+from torchvision import datasets
+from megatron import get_args
+from megatron.data.image_folder import ImageFolder
 from megatron.data.autoaugment import ImageNetPolicy
+from megatron.data.data_samplers import RandomSeedDataset

+class ClassificationTransform():
+    def __init__(self, image_size, train=True):
+        args = get_args()
+        assert args.fp16 or args.bf16
+        self.data_type = torch.half if args.fp16 else torch.bfloat16
+        if train:
+            self.transform = T.Compose([
+                T.RandomResizedCrop(image_size),
+                T.RandomHorizontalFlip(),
+                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
+                ImageNetPolicy(),
+                T.ToTensor(),
+                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+                T.ConvertImageDtype(self.data_type)
+            ])
+        else:
+            self.transform = T.Compose([
+                T.Resize(image_size),
+                T.CenterCrop(image_size),
+                T.ToTensor(),
+                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+                T.ConvertImageDtype(self.data_type)
+            ])
+
+    def __call__(self, input):
+        output = self.transform(input)
+        return output

-def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
+def build_train_valid_datasets(data_path, image_size=224):
+    args = get_args()
+    train_transform = ClassificationTransform(image_size)
+    val_transform = ClassificationTransform(image_size, train=False)

     # training dataset
-    train_data_path = os.path.join(data_path[0], "train")
-    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-    process = [
-        transforms.RandomResizedCrop(crop_size),
-        transforms.RandomHorizontalFlip(),
-    ]
-    if color_jitter:
-        process += [
-            transforms.ColorJitter(
-                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
-            )
-        ]
-    fp16_t = transforms.ConvertImageDtype(torch.half)
-    process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
-    transform_train = transforms.Compose(process)
-    train_data = datasets.ImageFolder(
-        root=train_data_path, transform=transform_train
+    train_data_path = data_path[0]
+    train_data = ImageFolder(
+        root=train_data_path,
+        transform=train_transform,
+        classes_fraction=args.classes_fraction,
+        data_per_class_fraction=args.data_per_class_fraction
     )
+    train_data = RandomSeedDataset(train_data)

     # validation dataset
-    val_data_path = os.path.join(data_path[0], "val")
-    transform_val = transforms.Compose(
-        [
-            transforms.Resize(crop_size),
-            transforms.CenterCrop(crop_size),
-            transforms.ToTensor(),
-            normalize,
-            fp16_t
-        ]
-    )
-    val_data = datasets.ImageFolder(
-        root=val_data_path, transform=transform_val
+    val_data_path = data_path[1]
+    val_data = ImageFolder(
+        root=val_data_path,
+        transform=val_transform
     )
+    val_data = RandomSeedDataset(val_data)

     return train_data, val_data
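Note the interface change: data_path is now a two-element list of explicit train and validation roots, rather than one root to which train/ and val/ were appended internally. A hypothetical call (placeholder paths; Megatron args must already be initialized, since the transform asserts fp16/bf16 and the dataset reads the fraction flags):

train_ds, val_ds = build_train_valid_datasets(
    data_path=['/data/imagenet/train', '/data/imagenet/val'],
    image_size=224,
)
img, label = train_ds[0]  # half/bf16 CHW tensor, int class index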
@@ -78,6 +78,12 @@ def load(args):
         scaled_masked_softmax_cuda = _cpp_extention_load_helper(
             "scaled_masked_softmax_cuda", sources, extra_cuda_flags)

+    # Softmax
+    sources=[srcpath / 'scaled_softmax.cpp',
+             srcpath / 'scaled_softmax_cuda.cu']
+    scaled_softmax_cuda = _cpp_extention_load_helper(
+        "scaled_softmax_cuda", sources, extra_cuda_flags)
+
     # =================================
     # Mixed precision fused layer norm.
     # =================================
...
@@ -90,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
    }
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features
 * 1) input scaling
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features
 * 1) input scaling
@@ -326,6 +437,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
    return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_forward(
     output_t *dst,
@@ -338,7 +541,7 @@ void dispatch_scaled_masked_softmax_forward(
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
     if (key_seq_len == 0) {
         return;
     } else {
@@ -410,6 +613,10 @@ void dispatch_scaled_masked_softmax_forward(
             scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
             break;
+        case 12: // 4096
+            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+            break;
         default:
             break;
     }
@@ -427,7 +634,7 @@ void dispatch_scaled_masked_softmax_backward(
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
     if (key_seq_len == 0) {
         return;
     } else {
@@ -498,6 +705,11 @@ void dispatch_scaled_masked_softmax_backward(
             scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
             break;
+        case 12: // 4096
+            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+            break;
         default:
             break;
     }
...
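The unmasked kernel computes softmax(scale * x) row-wise over the last dimension, padding each row up to the next power of two with -inf so that padded lanes contribute exp(-inf) = 0 to the sum. A sketch of a PyTorch reference one could check the compiled extension against (the scaled_softmax_cuda module name comes from megatron.fused_kernels.load; tolerances are illustrative):

import torch

def reference_scaled_softmax(x, scale):
    # Row-wise softmax of the scaled scores, accumulated in fp32.
    return torch.softmax(scale * x.float(), dim=-1).to(x.dtype)

# x: [batches, attn_heads, query_seq_len, key_seq_len], key_seq_len <= 4096
x = torch.randn(2, 4, 8, 1024, device='cuda', dtype=torch.half)
# out = scaled_softmax_cuda.forward(x, 0.125)  # fused kernel
# ref = reference_scaled_softmax(x, 0.125)
# assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)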
@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
   const int attn_heads = input.size(1);
   const int query_seq_len = input.size(2);
   const int key_seq_len = input.size(3);
-  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
   TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
...
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(
torch::Tensor const& input,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_softmax::fwd,
"Self Multihead Attention scaled, softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_softmax::bwd,
"Self Multihead Attention scaled, softmax -- Backward.");
}
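For y = softmax(scale * x), the gradient is dx = scale * y * (dy - sum(dy * y)) with the sum taken over the softmax dimension, which is why bwd_cuda below can reuse the scaled-masked backward kernel unchanged and operate in place on output_grads. A reference expression (sketch):

import torch

def reference_scaled_softmax_backward(grad_y, y, scale):
    # Jacobian-vector product of softmax(scale * x) with upstream grad_y.
    inner = (grad_y * y).sum(dim=-1, keepdim=True)
    return scale * y * (grad_y - inner)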
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_softmax_forward",
dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Random seeds for reproducibility.
     if args.rank == 0:
         print('> setting random seeds to {} ...'.format(args.seed))
-    _set_random_seed(args.seed)
+    _set_random_seed(args.seed, args.data_parallel_random_init)

     # Set pytorch JIT layer fusion options.
     _set_jit_fusion_options()
@@ -118,7 +118,7 @@ def _compile_dependencies():
         args.micro_batch_size
     # Constraints on sequence length and attn_batch_size to enable warp based
     # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
+    custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \
         seq_len % 4 == 0 and attn_batch_size % 4 == 0
     # Print a warning.
     if not ((args.fp16 or args.bf16) and
@@ -203,11 +203,14 @@ def _init_autoresume():
         torch.distributed.barrier()

-def _set_random_seed(seed_):
+def _set_random_seed(seed_, data_parallel_random_init=False):
     """Set random seed for reproducibility."""
     if seed_ is not None and seed_ > 0:
         # Ensure that different pipeline MP stages get different seeds.
         seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
...
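Worked example of the seed layout: with --seed 1234, pipeline stage 2, and data-parallel rank 3, the rank gets 1234 + 200 = 1434 by default, and 1434 + 30 = 1464 when --data-parallel-random-init is set. As a sketch mirroring _set_random_seed:

def megatron_seed(base, pp_rank, dp_rank, data_parallel_random_init=False):
    # Pipeline stages are offset by 100; data-parallel ranks by 10 when
    # per-replica randomness is requested.
    seed = base + 100 * pp_rank
    if data_parallel_random_init:
        seed += 10 * dp_rank
    return seed

assert megatron_seed(1234, 2, 3) == 1434
assert megatron_seed(1234, 2, 3, data_parallel_random_init=True) == 1464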
@@ -185,6 +185,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
             buffer_.zero()

+    def broadcast_params(self):
+        for param in self.module.parameters():
+            torch.distributed.broadcast(param.data,
+                                        src=mpu.get_data_parallel_src_rank(),
+                                        group=mpu.get_data_parallel_group())
+
     def allreduce_gradients(self):
         """Reduce gradients across data parallel ranks."""
         # If we have buffers, simply reduce the data in the buffer.
...
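broadcast_params gives callers a way to re-synchronize replicas: every parameter is overwritten with the data-parallel source rank's copy, undoing a per-rank random initialization whenever identical weights are required. A standalone sketch of the same pattern in plain torch.distributed (hypothetical helper; src and group are illustrative):

import torch.distributed as dist

def broadcast_params(module, src=0, group=None):
    # Overwrite every replica's parameters with the source rank's copy.
    for param in module.parameters():
        dist.broadcast(param.data, src=src, group=group)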
@@ -81,6 +81,37 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None

+class ScaledSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs following two operations in sequence
+    1. Scale the tensor.
+    2. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, scale):
+        import scaled_softmax_cuda
+
+        scale_t = torch.tensor([scale])
+
+        softmax_results = scaled_softmax_cuda.forward(
+            inputs, scale_t[0]
+        )
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        import scaled_softmax_cuda
+
+        softmax_results, scale_t = ctx.saved_tensors
+
+        input_grads = scaled_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
+        return input_grads, None, None

 class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax
@@ -137,12 +168,11 @@ class FusedScaleMaskSoftmax(nn.Module):
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
-            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and 16 < sk <= 4096  # sk must be 16 ~ 4096
             and sq % 4 == 0  # sq must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
-            if 0 <= sk <= 2048:
+            if 0 <= sk <= 4096:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                 if self.attn_mask_type == AttnMaskType.causal:
@@ -166,7 +196,10 @@ class FusedScaleMaskSoftmax(nn.Module):
                 return probs.view(b, np, sq, sk)
             else:
                 # input is 4D tensor (b, np, sq, sk)
-                return ScaledMaskedSoftmax.apply(input, mask, scale)
+                if mask is not None:
+                    return ScaledMaskedSoftmax.apply(input, mask, scale)
+                else:
+                    return ScaledSoftmax.apply(input, scale)

     def forward_torch_softmax(self, input, mask):
         if self.input_in_float16 and self.softmax_in_fp32:
...
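Dropping the mask requirement lets unmasked attention (e.g. ViT, which applies neither a causal nor a padding mask) take the fused path via ScaledSoftmax. A usage sketch (illustrative shapes; the input must be fp16/bf16 on GPU and the scaled_softmax_cuda extension must have been built):

import torch

# scores: [b, np, sq, sk] attention logits
scores = torch.randn(2, 4, 8, 1024, device='cuda', dtype=torch.half,
                     requires_grad=True)
# probs = ScaledSoftmax.apply(scores, 0.125)  # fused forward, no mask
# probs.sum().backward()                      # fused backward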
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer (ViT) model."""
import torch
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False,
pre_process=True, post_process=True):
super(VitClassificationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.finetune = finetune
self.pre_process = pre_process
self.post_process = post_process
self.backbone = VitBackbone(
pre_process=self.pre_process,
post_process=self.post_process,
single_token_output=True
)
if self.post_process:
if not self.finetune:
self.head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.head = get_linear_layer(
self.hidden_size,
self.num_classes,
torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.backbone.set_input_tensor(input_tensor)
def forward(self, input):
hidden_states = self.backbone(input)
if self.post_process:
hidden_states = self.head(hidden_states)
return hidden_states
@@ -18,16 +18,19 @@
 import math
 import einops
 import torch
+import apex
 import torch.nn.functional as F
 from megatron import get_args
+from megatron.model import LayerNorm
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import (
     get_linear_layer,
     init_method_normal,
     scaled_init_method_normal,
 )
-from .module import MegatronModule
+from megatron.model.module import MegatronModule

+CLASS_TOKEN_LENGTH = 8
 class VitMlpHead(MegatronModule):
     """Pooler layer.
@@ -44,19 +47,26 @@ class VitMlpHead(MegatronModule):
     def __init__(self, hidden_size, num_classes):
         super(VitMlpHead, self).__init__()
         self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
+        self.relu = torch.nn.ReLU()
         self.dense_out = torch.nn.Linear(hidden_size, num_classes)
         torch.nn.init.constant_(self.dense_out.bias, -10)

-    def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
-        # sequence_index: index of the token to pool.
-        hidden_state = hidden_states[:, sequence_index, :]
-        dense_in_result = self.dense_in(hidden_state)
+    def forward(self, hidden_states):
+        # hidden_states: [b, 1, h]
+        dense_in_result = self.dense_in(hidden_states)
         tanh_result = torch.tanh(dense_in_result)
         dense_out_result = self.dense_out(tanh_result)
         return dense_out_result

+def isPerfectSquare(x):
+    if(x >= 0):
+        sr = math.sqrt(x)
+        return (int(sr) * int(sr) == x)
+    return False
def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
...@@ -68,66 +78,77 @@ def twod_interpolate_position_embeddings_hook(
):
    args = get_args()
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert isPerfectSquare(input_seq_len) or \
            isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        # update input_param and load it to state_dict[key]
        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
            input_param_grid = input_param

        assert input_param.shape[1] == hidden_size

        if num_tok_input != num_tok_output:
            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

            input_param = input_param_grid
            assert (
                input_param.shape[0] == num_tok_output
                and input_param.shape[1] == hidden_size
            )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param
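In isolation, the resize step above treats the grid tokens as a small 2D feature map and bilinearly resamples it. A minimal standalone sketch (illustrative tensors only, using size= rather than the hook's scale_factor=; gs_input=14 with a 14x28 target corresponds to a hypothetical 224x448 image and 16x16 patches):

import torch
import torch.nn.functional as F

hidden_size = 768
gs_input = 14                      # sqrt(196): source checkpoint grid
gs_new = (14, 28)                  # target grid for a non-square image

grid = torch.randn(gs_input * gs_input, hidden_size)   # (196, 768)

# tokens -> (1, hidden, h, w), the layout F.interpolate expects
grid = grid.transpose(0, 1).reshape(1, hidden_size, gs_input, gs_input)
grid = F.interpolate(grid, size=gs_new, mode="bilinear")
# back to token layout: (h * w, hidden)
grid = grid.reshape(hidden_size, gs_new[0] * gs_new[1]).transpose(0, 1).contiguous()
assert grid.shape == (14 * 28, hidden_size)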
class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
...@@ -142,25 +163,33 @@ class VitModel(MegatronModule):
        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # cls_token
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)

            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
...@@ -173,8 +202,8 @@ class VitModel(MegatronModule):
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )

            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )
...@@ -183,21 +212,12 @@ class VitModel(MegatronModule):
        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)
...@@ -214,21 +234,22 @@
            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

        if self.single_token_output:
            hidden_states = hidden_states[:, 0, :]

        return hidden_states
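The forward pass above receives rearranged_input from a patch-flattening step that the diff elides. Conceptually it is the standard ViT patchify; an einops-style sketch under that assumption, with shapes chosen to match flatten_dim = patch_dim * patch_dim * num_channels:

import torch
import einops

patch_dim, num_channels = 16, 3
img = torch.randn(4, num_channels, 224, 224)        # (batch, c, h, w)

# One row of length patch_dim * patch_dim * c per patch.
patches = einops.rearrange(
    img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
    p1=patch_dim, p2=patch_dim,
)
assert patches.shape == (4, (224 // 16) * (224 // 16), 16 * 16 * 3)  # (4, 196, 768)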
...@@ -38,6 +38,7 @@ from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_data_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
...@@ -452,6 +452,15 @@ def get_tensor_model_parallel_src_rank():
    return (global_rank // local_world_size) * local_world_size


def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group."""
    global_rank = torch.distributed.get_rank()
    data_parallel_size = get_data_parallel_world_size()
    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
    return global_rank % num_data_parallel_groups
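A small sanity check on the arithmetic, with hypothetical sizes: for a world size of 8 and a data-parallel size of 2 there are 4 model-parallel ranks per replica, so ranks 4..7 map back onto sources 0..3.

def data_parallel_src_rank(global_rank, world_size, data_parallel_size):
    # Mirrors get_data_parallel_src_rank() without torch.distributed state.
    num_data_parallel_groups = world_size // data_parallel_size
    return global_rank % num_data_parallel_groups

assert [data_parallel_src_rank(r, 8, 2) for r in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]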
def get_pipeline_model_parallel_first_rank():
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
...@@ -35,14 +35,14 @@ def _get_params_for_weight_decay_optimization(modules):
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend(
                [p for p in list(module_._parameters.values())
                 if p is not None and p.requires_grad])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and p.requires_grad and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and p.requires_grad and n == 'bias'])

    return weight_decay_params, no_weight_decay_params
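The new p.requires_grad guard keeps frozen parameters out of both optimizer groups. A toy standalone illustration, assuming a linear layer whose weight has been frozen (e.g. for fine-tuning):

import torch

layer = torch.nn.Linear(4, 4)
layer.weight.requires_grad = False   # frozen, so it must not reach the optimizer

weight_decay, no_weight_decay = [], []
for name, p in layer._parameters.items():
    if p is None or not p.requires_grad:
        continue                     # same filter as the list comprehensions above
    (no_weight_decay if name == 'bias' else weight_decay).append(p)

assert weight_decay == []
assert len(no_weight_decay) == 1 and no_weight_decay[0] is layer.bias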
...@@ -285,7 +285,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                 args.accumulate_allreduce_grads_in_fp32,
                 args.use_contiguous_buffers_in_local_ddp)
                 for model_module in model]

        # broadcast params from data parallel src rank to other data parallel ranks
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()
    else:
        raise NotImplementedError('Unknown DDP implementation specified: '
                                  '{}. Exiting.'.format(args.DDP_impl))
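broadcast_params() lives on the local DDP wrapper and its body is not shown in this diff; the idea is a per-parameter broadcast from each data-parallel group's source rank, so that --data-parallel-random-init still yields identical replicas before training. A hedged sketch, assuming an initialized process group and that mpu is megatron.mpu:

import torch
from megatron import mpu

def broadcast_params(module):
    # Every data-parallel rank overwrites its randomly initialized copy
    # with the source rank's parameters.
    for param in module.parameters():
        torch.distributed.broadcast(
            param.data,
            src=mpu.get_data_parallel_src_rank(),
            group=mpu.get_data_parallel_group(),
        )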
...@@ -21,7 +21,7 @@ from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...@@ -31,9 +31,9 @@ def model_provider(pre_process=True, post_process=True):
    print_rank_0("building VIT model ...")
    args = get_args()

    model = VitClassificationModel(num_classes=args.num_classes,
                                   pre_process=pre_process,
                                   post_process=post_process)
    return model


def get_batch(data_iterator):
...@@ -82,7 +82,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
    print_rank_0(
        "> building train, validation, and test datasets for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None
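build_train_valid_datasets now takes the target image size explicitly instead of deriving it from a single square img_dim. For intuition only, a torchvision-flavored sketch of such a builder (hypothetical helper name and directory layout; the real implementation lives in megatron/data/vit_dataset.py and is not part of this diff):

import torchvision.transforms as T
from torchvision.datasets import ImageFolder

def build_train_valid_datasets_sketch(data_path, image_size=(224, 224)):
    # Hypothetical layout: <data_path>/train and <data_path>/val ImageFolders.
    train_transform = T.Compose([
        T.RandomResizedCrop(image_size),   # (h, w) tuple, so non-square works
        T.RandomHorizontalFlip(),
        T.ToTensor(),
    ])
    val_transform = T.Compose([
        T.Resize(image_size),
        T.CenterCrop(image_size),
        T.ToTensor(),
    ])
    train_ds = ImageFolder(f"{data_path}/train", transform=train_transform)
    valid_ds = ImageFolder(f"{data_path}/val", transform=val_transform)
    return train_ds, valid_ds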