Unverified Commit da01c234 authored by Frank Lee, committed by GitHub

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
from .caltech101_dataset import Caltech101Dataset
from .cifar10_dataset import CIFAR10Dataset
from .sampler import *
import numpy as np
def pil_img_to_numpy(pil_img):
"""convert a PIL image to numpy nd-array
:param pil_img: a PIL image
:type pil_img: PIL.Image
:return: a nd-array
:rtype: numpy.ndarray
"""
np_img = np.array(pil_img)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
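Note: a quick check of the HWC-to-CHW conversion above, assuming the pil_img_to_numpy helper is importable (illustrative only, not part of the commit):

```python
import numpy as np
from PIL import Image

# A 32x32 RGB image stored as (H, W, C) uint8.
pil_img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))

np_img = pil_img_to_numpy(pil_img)
# np.rollaxis(arr, 2) moves the channel axis to the front: (H, W, C) -> (C, H, W).
assert np_img.shape == (3, 32, 32)
```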
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from colossalai.builder import build_transform
class BaseDataset(Dataset, ABC):
def __init__(self, transform_pipeline: list):
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
transform = transforms.Compose(transform_list)
self._transform_pipeline = transform
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import Caltech101
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class Caltech101Dataset(BaseDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
    :param transform_pipeline: a list of transform configs; the composed transform takes in a PIL image
        and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = Caltech101(
transform=self._transform_pipeline, *args, **kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
        :return: ((image,), (target,)), where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)
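Note: both dataset wrappers use the same rank-0-downloads-first pattern: non-zero ranks wait at a barrier while rank 0 downloads, then rank 0 waits at a second barrier once the files are cached. A minimal standalone sketch of that pattern (assumes an initialized torch.distributed process group; download_fn is a hypothetical placeholder):

```python
import torch.distributed as dist

def download_on_rank0_first(download_fn, *args, **kwargs):
    """Run download_fn on rank 0 before any other rank touches the data."""
    use_dist = dist.is_available() and dist.is_initialized()
    if use_dist and dist.get_rank() != 0:
        dist.barrier()                 # wait for rank 0 to finish downloading
    result = download_fn(*args, **kwargs)
    if use_dist and dist.get_rank() == 0:
        dist.barrier()                 # release the other ranks
    return result
```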
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import CIFAR10
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class CIFAR10Dataset(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
    :param transform_pipeline: a list of transform configs; the composed transform takes in a PIL image
        and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = CIFAR10(transform=self._transform_pipeline,
*args,
**kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
        :return: ((image,), (target,)), where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler
__all__ = ['BaseSampler', 'DataParallelSampler']
import math
from torch import Tensor
from torch.nn import init as init
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)
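Note: a rough sanity check of the initializers above (hypothetical usage). The 'jax' branch draws from U(-a, a) with a = sqrt(3) * sqrt(2 / (fan_in + fan_out)), i.e. classic Glorot/Xavier uniform, so the resulting standard deviation should match sqrt(2 / (fan_in + fan_out)):

```python
import math
import torch

fan_in, fan_out = 512, 2048
w = torch.empty(fan_out, fan_in)

init_weight_(w, fan_in, fan_out, init_method='jax')
expected_std = math.sqrt(2.0 / (fan_in + fan_out))
# Uniform(-a, a) has std a / sqrt(3), so the empirical std should be close to expected_std.
assert abs(w.std().item() - expected_std) / expected_std < 0.05

b = torch.empty(fan_out)
init_bias_(b, fan_in, init_method='zero')
assert torch.all(b == 0)
```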
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .non_parallel_layers import *
from .vanilla_resnet import *
from .vanilla_vision_transformer import *
from .wrapper import *
@@ -2,14 +2,39 @@
# -*- encoding: utf-8 -*-

import math
import collections.abc
from itertools import repeat

import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
from torch import Tensor, nn


class CheckpointModule(nn.Module):
    def __init__(self, checkpoint: bool = True):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint

    def _forward(self, *args, **kwargs):
        raise NotImplementedError(
            'CheckpointModule should implement _forward method instead of origin forward')

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            return checkpoint(self._forward, *args, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    def train(self, mode: bool = True):
        self._use_checkpoint = self.checkpoint
        return super().train(mode=mode)

    def eval(self):
        self._use_checkpoint = False
        return super().eval()


def divide(numerator, denominator):
    """ only allow exact division """
@@ -18,46 +43,30 @@ def divide(numerator, denominator):
    return numerator // denominator


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


def set_tensor_parallel_attribute_by_size(param, size):
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))


def set_tensor_parallel_attribute_by_partition(param, num_partitions):
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, num_partitions)


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
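Note: a quick numerical check (hypothetical, not part of the commit) that the fused tanh approximation stays close to the exact erf-based GeLU of input + bias:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16, dtype=torch.float32)
bias = torch.randn(16, dtype=torch.float32)

fused = bias_gelu_impl(x, bias)          # tanh approximation, bias folded in
reference = F.gelu(x + bias)             # exact erf-based GeLU

# The tanh approximation differs from the exact GeLU by well under 1e-2.
assert torch.allclose(fused, reference, atol=1e-2)
```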
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
__all__ = [
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from torch import nn as nn

from colossalai.builder import build_layer
from colossalai.registry import LAYERS
from .._common_utils import to_2tuple


@LAYERS.register_module
class ViTBlock(nn.Module):
    """Vision Transformer block

    :param attention_cfg: config of attention layer
    :type attention_cfg: dict
    :param droppath_cfg: config of drop path
    :type droppath_cfg: dict
    :param mlp_cfg: config of MLP layer
    :type mlp_cfg: dict
    :param norm_cfg: config of normalization layer
    :type norm_cfg: dict
    """

    def __init__(self,
                 attention_cfg: dict,
                 droppath_cfg: dict,
                 mlp_cfg: dict,
                 norm_cfg: dict,
                 ):
        super().__init__()
        self.norm1 = build_layer(norm_cfg)
        self.attn = build_layer(attention_cfg)
        self.drop_path = build_layer(
            droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
        self.norm2 = build_layer(norm_cfg)
        self.mlp = build_layer(mlp_cfg)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead

__all__ = [
    'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
    'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]
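Note: the intent behind the Linear1D_Col / Linear1D_Row pair (Megatron-style 1D tensor parallelism) is that the first linear layer is split along its output dimension, so each rank computes a slice of the intermediate activation, and the second is split along its input dimension, so the partial outputs only need a single all-reduce. A single-process sketch of the underlying algebra (assumed illustration, no distributed setup):

```python
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(8, 16)          # (batch, hidden)
w1 = torch.randn(64, 16)        # column-parallel weight: split along dim 0 (output features)
w2 = torch.randn(16, 64)        # row-parallel weight: split along dim 1 (input features)

# Full (non-parallel) reference.
reference = torch.relu(x @ w1.t()) @ w2.t()

# Each "rank" holds one shard of w1 and the matching shard of w2.
w1_shards = w1.chunk(world_size, dim=0)
w2_shards = w2.chunk(world_size, dim=1)
partials = [torch.relu(x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards)]

# The all-reduce in Linear1D_Row corresponds to summing the partial results.
assert torch.allclose(sum(partials), reference, atol=1e-4)
```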
import torch
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
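Note: FusedLayerNormAffineFunction1D depends on the optional fused_mix_prec_layer_norm_cuda extension (the import above falls back to None when it is missing). The same affine layer norm can be computed with the stock functional op; a hedged fallback sketch, not part of this commit:

```python
import torch
import torch.nn.functional as F

def layer_norm_affine(x: torch.Tensor,
                      weight: torch.Tensor,
                      bias: torch.Tensor,
                      normalized_shape,
                      eps: float = 1e-5) -> torch.Tensor:
    # Equivalent math to the fused kernel: normalize over normalized_shape,
    # then apply the elementwise affine transform.
    return F.layer_norm(x, normalized_shape, weight=weight, bias=bias, eps=eps)

x = torch.randn(4, 10, 256)
weight = torch.ones(256)
bias = torch.zeros(256)
out = layer_norm_affine(x, weight, bias, (256,))
assert out.shape == x.shape
```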
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide, ACT2FN
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
@LAYERS.register_module
class TransformerMLP1D(ParallelLayer):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super(TransformerMLP1D, self).__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
bias=not skip_bias_add,
dtype=dtype,
gather_output = False,
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
bias=not skip_bias_add,
dtype=dtype,
parallel_input = True,
)
self.dropout = nn.Dropout(dropout_prob)
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
def forward(self, x):
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention1D(ParallelLayer):
"""Self attention layer for 1D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
# need to re-enable torch grad to enable fused optimization.
# self.layernorm = LayerNorm1D(
# hidden_size,
# dtype=dtype)
self.layernorm = nn.LayerNorm(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer1D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
    :param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
...@@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): ...@@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size) per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from colossalai import context
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from .layers import Linear1D_Col, Linear1D_Row
from ..base_layer import ParallelLayer
from .._common_utils import to_2tuple
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP1D(ParallelLayer):
"""MLP layer for 1D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
weight_init='torch'
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
init_weight=weight_init,
init_bias=weight_init
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention1D(ParallelLayer):
"""Self-attention layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_bias = 'zero'
else:
init_bias = weight_init
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead1D(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_weight = 'zero'
init_bias = 'zero'
else:
init_weight = weight_init
init_bias = weight_init
self.linear = Linear1D_Col(
hidden_size,
num_classes,
dtype=dtype,
gather_output=True,
init_weight=init_weight,
init_bias=init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTHead(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
self.linear = nn.Linear(
hidden_size,
num_classes,
dtype=dtype
)
self._broadcast_linear_params()
def _broadcast_linear_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.linear.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.linear.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding1D(ParallelLayer):
""" 2D Image to Patch Embedding
    :param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
if weight_init == 'jax':
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.proj.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.proj.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
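Note: shape arithmetic behind the patch embedding above, as a standalone sketch (hypothetical numbers): a 224x224 image with 16x16 patches yields a 14x14 grid, i.e. 196 tokens, and the flatten step turns the BCHW conv output into (batch, num_patches, embed_dim).

```python
import torch
from torch import nn

img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 768
grid = img_size // patch_size                    # 14
num_patches = grid * grid                        # 196

proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, in_chans, img_size, img_size)
out = proj(x)                                    # (2, 768, 14, 14)
out = out.flatten(2).transpose(1, 2)             # BCHW -> BNC: (2, 196, 768)
assert out.shape == (2, num_patches, embed_dim)
```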
@LAYERS.register_module
class ViTTokenFuser1D(ParallelLayer):
"""
    Fuses the class token and position embedding into the input sequence
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.empty(
1, self.num_patches + 1, self.embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
dist.broadcast(self.pos_embed,
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
self.pos_drop = nn.Dropout(p=drop_rate)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x.contiguous()
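Note: the token fuser prepends a learnable class token and adds a broadcast position embedding, so the sequence length grows from num_patches to num_patches + 1. A minimal single-process sketch of the same tensor manipulation (illustrative only):

```python
import torch

batch, num_patches, embed_dim = 2, 196, 768
patch_tokens = torch.randn(batch, num_patches, embed_dim)

cls_token = torch.zeros(1, 1, embed_dim)
pos_embed = torch.randn(1, num_patches + 1, embed_dim)

cls = cls_token.expand(batch, -1, -1)            # (2, 1, 768), no data copy
x = torch.cat((cls, patch_tokens), dim=1)        # (2, 197, 768)
x = x + pos_embed                                # position embedding broadcasts over batch
assert x.shape == (batch, num_patches + 1, embed_dim)
```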
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
import numbers

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple

from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import FusedLayerNormAffineFunction1D
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
    split_forward_gather_backward
from ..base_layer import ParallelLayer


@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    """Linear layer with column parallelism.
@@ -44,23 +50,29 @@ class Linear1D_Col(ParallelLayer):
                 output_size: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 init_weight='torch',
                 init_bias='torch'
                 ):
        super().__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)

        # Parameters.
        # Initialize weight.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.in_features,
            **factory_kwargs))
        if bias:
@@ -72,6 +84,45 @@ class Linear1D_Col(ParallelLayer):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        with seed(ParallelMode.TENSOR):
            self.reset_parameters(init_weight, init_bias)
        self._set_tensor_parallel_attributes()

    def reset_parameters(self, init_weight, init_bias) -> None:
        assert init_weight in ('torch', 'jax', 'zero')
        assert init_bias in ('torch', 'jax', 'zero')
        # setting
        fan_in, fan_out = self.in_features, self.out_features

        # init weight
        if init_weight == 'torch':
            a = math.sqrt(5)
            nonlinearity = 'leaky_relu'
            std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
            bound = math.sqrt(3.0) * std
            init.uniform_(self.weight, -bound, bound)
        elif init_weight == 'jax':
            std = math.sqrt(2.0 / float(fan_in + fan_out))
            a = math.sqrt(3.0) * std
            init.uniform_(self.weight, -a, a)
        elif init_weight == 'zero':
            init.zeros_(self.weight)

        # init bias
        if self.bias is not None:
            if init_bias == 'torch':
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(self.bias, -bound, bound)
            elif init_bias == 'jax':
                init.normal_(self.bias, std=1e-6)
            elif init_bias == 'zero':
                init.zeros_(self.bias)

    def _set_tensor_parallel_attributes(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
        if self.bias is not None:
            set_tensor_parallel_attribute_by_partition(self.bias, num_partition)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        # Set up backprop all-reduce.
@@ -104,7 +155,7 @@ class Linear1D_Row(ParallelLayer):
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param parallel_input: If set to ``True``, it's assumed that the input is already split, defaults to False
    :type parallel_input: bool, optional
    """
@@ -113,7 +164,10 @@ class Linear1D_Row(ParallelLayer):
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 parallel_input: bool = False,
                 skip_bias_add: bool = False,
                 init_weight='torch',
                 init_bias='torch'
                 ):
        super().__init__()
@@ -121,11 +175,13 @@ class Linear1D_Row(ParallelLayer):
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)

        # Parameters.
        # Initialize weight.
@@ -146,9 +202,46 @@ class Linear1D_Row(ParallelLayer):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        with seed(ParallelMode.TENSOR):
            self.reset_parameters(init_weight, init_bias)
        self._set_tensor_parallel_attributes()

    def reset_parameters(self, init_weight, init_bias) -> None:
        assert init_weight in ('torch', 'jax', 'zero')
        assert init_bias in ('torch', 'jax', 'zero')
        # setting
        fan_in, fan_out = self.in_features, self.out_features

        # init weight
        if init_weight == 'torch':
            a = math.sqrt(5)
            nonlinearity = 'leaky_relu'
            std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
            bound = math.sqrt(3.0) * std
            init.uniform_(self.weight, -bound, bound)
        elif init_weight == 'jax':
            std = math.sqrt(2.0 / float(fan_in + fan_out))
            a = math.sqrt(3.0) * std
            init.uniform_(self.weight, -a, a)
        elif init_weight == 'zero':
            init.zeros_(self.weight)

        # init bias
        if self.bias is not None:
            if init_bias == 'torch':
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(self.bias, -bound, bound)
            elif init_bias == 'jax':
                init.normal_(self.bias, std=1e-6)
            elif init_bias == 'zero':
                init.zeros_(self.bias)
            dist.broadcast(self.bias,
                           src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
                           group=gpc.get_group(ParallelMode.PARALLEL_1D))

    def _set_tensor_parallel_attributes(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)

    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
@@ -163,4 +256,29 @@ class Linear1D_Row(ParallelLayer):
        if not self.skip_bias_add:
            output = output + self.bias
            return output
        else:
            return output, self.bias


@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-5):
        super(MixedFusedLayerNorm1D, self).__init__()

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        return FusedLayerNormAffineFunction1D.apply(
            input, self.weight, self.bias, self.normalized_shape, self.eps)
@@ -20,7 +20,6 @@ def matmul_2d(a,
                  col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
                  ):
    """Matrix multiplication for 2D parallelism
    :param a: matrix :math:`A`
    :type a: torch.tensor
    :param b: matrix :math:`B`
@@ -86,25 +85,30 @@ class Matmul_AB_2D(torch.autograd.Function):
            ctx.save_for_backward(A, B)

        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1])).contiguous()
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1])).contiguous()
        C_shape = (A.shape[0], B.shape[-1])
        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
        op_a.wait()
        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
        for op in [op_a, op_b]:
            op.wait()

        for i in range(summa_dim):
            src_a = i + summa_dim * row_rank
            src_b = i + summa_dim * col_rank
            src_a = src_a % summa_dim
            src_b = src_b % summa_dim
            A_temp = A_list[src_a]
            B_temp = B_list[src_b]
            torch.addmm(C, A_temp, B_temp, out=C)

        out = C.reshape(out_shape)

        if ctx:
@@ -499,36 +503,61 @@ class _LayerNorm_2D(torch.autograd.Function):
        # return input_grad, None, None, None, None, None


class AllGatherLast(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any,
                inputs: Tensor,
                summa_dim: int,
                col_parallel_mode: ParallelMode) -> Tensor:
        ctx.summa_dim = summa_dim
        ctx.row_rank = gpc.get_local_rank(col_parallel_mode)

        last_dim = summa_dim * inputs.size(-1)
        outputs_shape = (last_dim,) + inputs.shape[:-1]
        outputs = torch.empty(
            outputs_shape, dtype=inputs.dtype, device=get_current_device())
        dist.all_gather(
            list(outputs.chunk(summa_dim, dim=0)),
            inputs.permute(2, 0, 1).contiguous(),
            group=gpc.get_group(col_parallel_mode)
        )
        outputs = outputs.permute(1, 2, 0).contiguous()
        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
        return grad.contiguous(), None, None


class SplitFirst(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any,
                inputs: Tensor,
                summa_dim: int,
                col_parallel_mode: ParallelMode) -> Tensor:
        ctx.summa_dim = summa_dim
        ctx.batch_size = inputs.size(0)
        ctx.para_mode = col_parallel_mode
        row_rank = gpc.get_local_rank(col_parallel_mode)

        outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
        grad = torch.empty(
            grad_shape, dtype=output_grad.dtype, device=get_current_device())
        dist.all_gather(
            list(grad.chunk(ctx.summa_dim, dim=0)),
            output_grad.contiguous(),
            group=gpc.get_group(ctx.para_mode)
        )
        return grad, None, None
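Note: Matmul_AB_2D above implements a SUMMA-style blocked matmul on a q x q process grid: after all-gathering the row of A-blocks and the column of B-blocks, each rank accumulates C_ij = sum_k A_ik B_kj. A serial sketch of that block algebra (assumed illustration, no process groups):

```python
import torch

torch.manual_seed(0)
summa_dim, block = 2, 3                      # 2x2 grid of 3x3 blocks
A = torch.randn(summa_dim * block, summa_dim * block)
B = torch.randn(summa_dim * block, summa_dim * block)

def get_block(M, i, j):
    return M[i * block:(i + 1) * block, j * block:(j + 1) * block]

C = torch.zeros_like(A)
for i in range(summa_dim):          # row rank
    for j in range(summa_dim):      # column rank
        # Each (i, j) rank would own A_{i,*} (gathered along its row group)
        # and B_{*,j} (gathered along its column group).
        for k in range(summa_dim):
            C[i * block:(i + 1) * block, j * block:(j + 1) * block] += \
                get_block(A, i, k) @ get_block(B, k, j)

assert torch.allclose(C, A @ B, atol=1e-5)
```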
...@@ -5,19 +5,21 @@ import math ...@@ -5,19 +5,21 @@ import math
import torch import torch
from torch import nn as nn, Tensor, distributed as dist from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import checkpoint from colossalai.utils import checkpoint
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from ._operation import _ViT_Split_Input_2D from colossalai.core import global_context as gpc
from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module @LAYERS.register_module
...@@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer): ...@@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer):
act_func: str = 'gelu', act_func: str = 'gelu',
dropout_prob: float = 0., dropout_prob: float = 0.,
dtype=None, dtype=None,
checkpoint: bool = False checkpoint: bool = False,
): weight_init='torch'):
super().__init__() super().__init__()
assert_summa_initialization() assert_summa_initialization()
...@@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer): ...@@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer):
self.in_features = in_features self.in_features = in_features
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h. # Project to mlp_ratio * h.
self.dense_1 = Linear2D( self.dense_1 = Linear2D(
self.in_features, self.in_features,
self.mlp_ratio * self.in_features, self.mlp_ratio * self.in_features,
dtype=dtype, dtype=dtype,
init_weight=weight_init, init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
) )
self.act = ACT2FN[act_func]
# Project back to h. # Project back to h.
self.dense_2 = Linear2D( self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features, self.mlp_ratio * self.in_features,
self.in_features, self.in_features,
dtype=dtype, dtype=dtype,
init_weight=weight_init, init_bias=weight_init
) )
self.dropout = nn.Dropout(dropout_prob) self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor: def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states) if self.act == bias_gelu_impl:
intermediate_output = self.act(intermediate_output) intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output) intermediate_output = self.dropout(intermediate_output)
...@@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer): ...@@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer):
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
-                checkpoint: bool = False
-                ):
+                checkpoint: bool = False,
+                weight_init='torch'):
        super().__init__()
        assert_summa_initialization()
@@ -128,17 +143,24 @@ class ViTSelfAttention2D(ParallelLayer):
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.checkpoint = checkpoint
+        assert weight_init in ('torch', 'jax')
+        if weight_init == 'jax':
+            self.init_bias = 'zero'
+        else:
+            self.init_bias = weight_init
        self.query_key_value = Linear2D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
+            init_weight=weight_init, init_bias=self.init_bias
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2D(
            hidden_size,
            hidden_size,
            dtype=dtype,
+            init_weight=weight_init, init_bias=self.init_bias
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)
@@ -146,7 +168,7 @@ class ViTSelfAttention2D(ParallelLayer):
    def _forward(self, hidden_states: Tensor) -> Tensor:
        query_key_value = self.query_key_value(hidden_states)
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(
@@ -155,7 +177,7 @@ class ViTSelfAttention2D(ParallelLayer):
        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
@@ -165,7 +187,7 @@ class ViTSelfAttention2D(ParallelLayer):
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        output = self.dense(context_layer)
@@ -199,14 +221,22 @@ class ViTHead2D(ParallelLayer):
                 hidden_size,
                 num_classes,
                 dtype=None,
-                ):
+                weight_init='torch'):
        super().__init__()
        assert_summa_initialization()
+        assert weight_init in ('torch', 'jax')
+        if weight_init == 'jax':
+            self.init_weight = 'zero'
+            self.init_bias = 'zero'
+        else:
+            self.init_weight = weight_init
+            self.init_bias = weight_init
        self.summa_dim = get_summa_dim_from_env()
        self.linear = Linear2D(
            hidden_size,
            num_classes,
            dtype=dtype,
+            init_weight=self.init_weight, init_bias=self.init_bias
        )

    def forward(self, x: Tensor) -> Tensor:
@@ -236,7 +266,8 @@ class ViTPatchEmbedding2D(ParallelLayer):
                 patch_size,
                 embed_dim,
                 in_chans=3,
-                flatten=True):
+                flatten=True,
+                weight_init='torch'):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
@@ -249,39 +280,28 @@ class ViTPatchEmbedding2D(ParallelLayer):
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
-        self.embed_dim = embed_dim // self.summa_dim
+        self.embed_dim = embed_dim // (self.summa_dim ** 2)
        with seed(ParallelMode.TENSOR):
-            # ensure the partitions are initialized differently
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_dim,
                                  kernel_size=patch_size,
-                                  stride=patch_size
+                                  stride=patch_size,
+                                  device=get_current_device()
                                  )
        self._set_tensor_parallel_attribute()
-        # sync
-        self._broadcast_conv_params()
-        self.proj.weight.register_hook(self._sync_grad_during_backward)
-        self.proj.bias.register_hook(self._sync_grad_during_backward)
+        if weight_init == 'jax':
+            with seed(ParallelMode.TENSOR):
+                fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+                std = math.sqrt(1.0 / fan_in)
+                nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+                nn.init.zeros_(self.proj.bias)

    def _set_tensor_parallel_attribute(self):
-        set_tensor_parallel_attribute(self.proj.weight)
-        set_tensor_parallel_attribute(self.proj.bias)
+        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+        set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
+        set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)

-    def _broadcast_conv_params(self) -> None:
-        self.to(get_current_device())
-        ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
-        dist.broadcast(self.proj.weight, src=ranks_in_col[0],
-                       group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-        dist.broadcast(self.proj.bias, src=ranks_in_col[0],
-                       group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))

-    def _sync_grad_during_backward(self, grad: Tensor) -> None:
-        dist.all_reduce(grad, group=gpc.get_group(
-            ParallelMode.PARALLEL_2D_COL))
-        grad = grad / self.summa_dim
-        return grad

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
@@ -293,6 +313,24 @@ class ViTPatchEmbedding2D(ParallelLayer):
        return x
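The 'jax' branch added above follows the variance-scaling ("lecun normal") recipe used for flax/timm ViT weights: a truncated normal with variance 1/fan_in, where the requested std is divided by ~0.87962566 (the standard deviation of a unit normal truncated to (-2, 2)). Below is a minimal single-process sketch of that recipe; the helper name lecun_normal_ is ours and not part of this diff.

import math

import torch
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out


def lecun_normal_(weight: torch.Tensor, bias: torch.Tensor = None) -> None:
    # jax/flax-style variance scaling: target std = sqrt(1 / fan_in); the division by
    # ~0.87962566 (std of a standard normal truncated to (-2, 2)) compensates for the
    # variance lost to truncation in trunc_normal_.
    fan_in, _ = _calculate_fan_in_and_fan_out(weight)
    std = math.sqrt(1.0 / fan_in)
    nn.init.trunc_normal_(weight, std=std / .87962566103423978)
    if bias is not None:
        nn.init.zeros_(bias)


# a single-GPU patch projection initialized with the same recipe as the 'jax' branch above
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
lecun_normal_(proj.weight, proj.bias)

Passing weight_init='jax' to the layers above selects this scheme; weight_init='torch' keeps PyTorch's default initialization.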
+@LAYERS.register_module
+class ViTInputSplitter2D(ParallelLayer):
+    """Split the input tensor for 2D parallel Vision Transformer
+    """
+
+    def __init__(self):
+        super().__init__()
+        assert_summa_initialization()
+        self.summa_dim = get_summa_dim_from_env()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = AllGatherLast.apply(
+            x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+        x = SplitFirst.apply(
+            x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+        return x
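One consequence of the change from embed_dim // summa_dim to embed_dim // (summa_dim ** 2) in these layers: on a q x q SUMMA mesh the tensor-parallel world size is q**2, so every parameter that carries the hidden dimension (the conv projection above, and the cls token / position embedding in ViTTokenFuser2D below) is sliced q**2-ways. A small sanity check on that arithmetic, with illustrative numbers that are not taken from the diff:

import torch

summa_dim = 2                                # q: the mesh is q x q (illustrative value)
tp_world_size = summa_dim ** 2               # ranks in the tensor-parallel group
embed_dim = 768
local_dim = embed_dim // (summa_dim ** 2)    # per-rank slice, as in the layers above

# each of the q**2 ranks allocates a [1, 1, local_dim] cls-token shard;
# together the shards account for the full hidden dimension
shards = [torch.zeros(1, 1, local_dim) for _ in range(tp_world_size)]
assert sum(s.shape[-1] for s in shards) == embed_dim
print(tp_world_size, local_dim)              # 4 192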
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
    """
@@ -328,64 +366,32 @@ class ViTTokenFuser2D(ParallelLayer):
        self.embed_dim = embed_dim
        self.cls_token = nn.Parameter(torch.zeros(
-            1, 1, self.embed_dim // self.summa_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(
-            1, self.num_patches + 1, self.embed_dim // self.summa_dim))
-        # move to cuda before broadcast
-        self.to(get_current_device())
-        # sync param in both forward and backward
-        _cls_token = self.cls_token.view(-1)
-        _pos_embed = self.pos_embed.view(-1)
-        self._param = torch.cat([_cls_token, _pos_embed], dim=0)
-        self._broadcast_params(self._param)
-        self._param.register_hook(self._sync_grad_hook)
+            (1, 1, self.embed_dim // (self.summa_dim ** 2)),
+            device=get_current_device()))
+        self.pos_embed = nn.Parameter(torch.empty(
+            (1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
+            device=get_current_device()))
+        with seed(ParallelMode.TENSOR):
+            nn.init.trunc_normal_(self.pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
-        set_tensor_parallel_attribute(self.cls_token)
-        set_tensor_parallel_attribute(self.pos_embed)
+        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+        set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
+        set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)

-    def _broadcast_params(self, param) -> None:
-        " broadcast to all column ranks for data consistency "
-        ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
-        col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
-        dist.broadcast(param, src=ranks_in_col[0],
-                       group=col_group)

-    def _sync_grad_hook(self, grad) -> None:
-        dist.all_reduce(grad, group=gpc.get_group(
-            ParallelMode.PARALLEL_2D_COL))
-        grad = grad / self.summa_dim
-        return grad

    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+        cls_token = AllGatherLast.apply(
+            self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+        cls_token = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
+        pos_embed = AllGatherLast.apply(
+            self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
+        x = x + pos_embed
        with seed(ParallelMode.TENSOR):
-            x = self.pos_drop(x + self.pos_embed)
+            x = self.pos_drop(x)
        return x
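For reference, the per-rank forward above is the 2D-parallel counterpart of the usual ViT token fusion: prepend the class token, add the position embedding, then apply dropout. A plain single-process sketch of that baseline computation, with illustrative shapes that are not taken from the diff:

import torch
from torch import nn

embed_dim, num_patches, drop_rate = 768, 196, 0.1
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, embed_dim))
nn.init.trunc_normal_(pos_embed, std=.02)
pos_drop = nn.Dropout(p=drop_rate)

x = torch.randn(4, num_patches, embed_dim)     # patch tokens for a batch of 4
cls = cls_token.expand(x.shape[0], -1, -1)     # [4, 1, 768]
x = torch.cat((cls, x), dim=1)                 # prepend the class token -> [4, 197, 768]
x = pos_drop(x + pos_embed)                    # add position embedding, then dropout
print(x.shape)                                 # torch.Size([4, 197, 768])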
-@LAYERS.register_module
-class ViTInputSplitter2D(ParallelLayer):
-    """Split the input tensor for 2D parallel Vision Transformer
-    """
-
-    def __init__(self):
-        super().__init__()
-        assert_summa_initialization()
-        self.summa_dim = get_summa_dim_from_env()
-
-    def forward(self, x: Tensor) -> Tensor:
-        batch_size = x.size(0)
-        return _ViT_Split_Input_2D.apply(
-            x,
-            batch_size,
-            self.summa_dim,
-            ParallelMode.PARALLEL_2D_COL
-        )