Unverified commit da01c234, authored by Frank Lee, committed by GitHub

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
from .caltech101_dataset import Caltech101Dataset
from .cifar10_dataset import CIFAR10Dataset
from .sampler import *
import numpy as np
def pil_img_to_numpy(pil_img):
"""convert a PIL image to numpy nd-array
:param pil_img: a PIL image
:type pil_img: PIL.Image
:return: a nd-array
:rtype: numpy.ndarray
"""
np_img = np.array(pil_img)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
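# --- Usage sketch (added for illustration, not part of the original file) ---
# pil_img_to_numpy turns an H x W x C PIL image into a C x H x W array, the
# channel-first layout expected by most torch vision models.
if __name__ == '__main__':
    from PIL import Image

    _demo_img = Image.new('RGB', (32, 32))   # hypothetical dummy image
    _demo_chw = pil_img_to_numpy(_demo_img)
    assert _demo_chw.shape == (3, 32, 32)    # CHW layout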
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from colossalai.builder import build_transform
class BaseDataset(Dataset, ABC):
def __init__(self, transform_pipeline: list):
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
transform = transforms.Compose(transform_list)
self._transform_pipeline = transform
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import Caltech101
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class Caltech101Dataset(BaseDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
:param transform_pipeline: a list of transform configs; the composed pipeline takes in a PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = Caltech101(
transform=self._transform_pipeline, *args, **kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import CIFAR10
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class CIFAR10Dataset(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
:param transform_pipeline: a list of transform configs; the composed pipeline takes in a PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = CIFAR10(transform=self._transform_pipeline,
*args,
**kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (label,)) where label is the class index of the image.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)
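# --- Configuration sketch (illustrative; the exact config schema is an
# assumption, not taken from this commit) ---
# The dataset classes above expect `transform_pipeline` to be a list of
# transform configs that colossalai.builder.build_transform can resolve,
# e.g. a `type` key naming a torchvision transform plus its keyword arguments:
#
#   transform_pipeline = [
#       dict(type='RandomCrop', size=32, padding=4),
#       dict(type='RandomHorizontalFlip'),
#       dict(type='ToTensor'),
#       dict(type='Normalize',
#            mean=(0.4914, 0.4822, 0.4465),
#            std=(0.2023, 0.1994, 0.2010)),
#   ]
#   dataset = CIFAR10Dataset(transform_pipeline, root='./data',
#                            train=True, download=True)
#   (img,), (label,) = dataset[0]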
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler
__all__ = ['BaseSampler', 'DataParallelSampler']
import math
from torch import Tensor
from torch.nn import init as init
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)
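# --- Usage sketch (added for illustration, not part of the original file) ---
# 'torch' mirrors nn.Linear's default Kaiming-uniform scheme, while 'jax'
# uses a Xavier-style uniform for weights and a near-zero normal for biases.
if __name__ == '__main__':
    from torch import nn

    _fc = nn.Linear(768, 3072)
    init_weight_(_fc.weight, fan_in=768, fan_out=3072, init_method='jax')
    init_bias_(_fc.bias, fan_in=768, init_method='jax')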
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .parallel_vision_transformer import *
from .vanilla_resnet import *
from .vanilla_vision_transformer import *
from .non_parallel_layers import *
from .wrapper import *
@@ -2,14 +2,39 @@
# -*- encoding: utf-8 -*-
import math
import collections.abc
from itertools import repeat
import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from torch import Tensor
from torch import nn
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
from torch import Tensor, nn
from colossalai.constants import IS_TENSOR_PARALLEL
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError(
'CheckpointModule should implement the _forward method instead of the original forward')
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args, **kwargs)
else:
return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()
def divide(numerator, denominator):
""" only allow exact division """
@@ -18,46 +43,30 @@ def divide(numerator, denominator):
return numerator // denominator
def gelu(x: Tensor) -> Tensor:
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute(param):
if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
def set_tensor_parallel_attribute_by_size(param, size):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, num_partitions)
def _forward(self, *args):
raise NotImplementedError(
'CheckpointModule should implement _forward method instead of origin forward')
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
def forward(self, *args):
if self._use_checkpoint:
return checkpoint(self._forward, *args)
else:
return self._forward(*args)
return parse
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()
to_2tuple = _ntuple(2)
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
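# --- Sanity-check sketch (illustrative, not part of the original file) ---
# bias_gelu_impl fuses the bias addition into the tanh approximation of GELU
# and supplies a hand-written backward.  The forward result should match the
# unfused reference (the `approximate` kwarg needs PyTorch >= 1.12).
if __name__ == '__main__':
    _x = torch.randn(4, 16)
    _b = torch.randn(16)
    _fused = bias_gelu_impl(_x, _b)
    _reference = torch.nn.functional.gelu(_x + _b, approximate='tanh')
    assert torch.allclose(_fused, _reference, atol=1e-5)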
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
__all__ = [
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
]
import collections.abc
from itertools import repeat
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
from .._common_utils import to_2tuple
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normalization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
to_2tuple = _ntuple(2)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
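# --- Configuration sketch (illustrative; the sub-layer config keys are
# assumptions, not taken from this commit) ---
# ViTBlock resolves each sub-layer config through colossalai.builder.build_layer
# and applies the standard pre-norm residual pattern:
#   x = x + DropPath(Attn(LN(x)));  x = x + DropPath(MLP(LN(x)))
# A hypothetical vanilla block could therefore be declared as:
#
#   block = ViTBlock(
#       attention_cfg=dict(type='VanillaViTAttention', dim=768, num_heads=12),
#       droppath_cfg=dict(type='VanillaViTDropPath', drop_path=0.1),
#       mlp_cfg=dict(type='VanillaViTMLP', in_features=768, hidden_features=3072),
#       norm_cfg=dict(type='LayerNorm', normalized_shape=768),
#   )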
@LAYERS.register_module
......
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
__all__ = [
'Linear1D_Col', 'Linear1D_Row',
'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]
import torch
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
\ No newline at end of file
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide, ACT2FN
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
@LAYERS.register_module
class TransformerMLP1D(ParallelLayer):
"""MLP.
The MLP takes an input of hidden size h, projects it to mlp_ratio * h,
applies a nonlinear transformation, and projects the result back to
hidden size h.
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super(TransformerMLP1D, self).__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
bias=not skip_bias_add,
dtype=dtype,
gather_output = False,
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
bias=not skip_bias_add,
dtype=dtype,
parallel_input = True,
)
self.dropout = nn.Dropout(dropout_prob)
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
def forward(self, x):
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention1D(ParallelLayer):
"""Self attention layer for 1D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
# need to re-enable torch grad to enable fused optimization.
# self.layernorm = LayerNorm1D(
# hidden_size,
# dtype=dtype)
self.layernorm = nn.LayerNorm(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer1D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
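# --- Partitioning sketch (illustrative, not part of the original file) ---
# With 1D tensor parallelism the QKV projection is column-parallel and the
# output projection is row-parallel, so each rank only holds a slice of the
# attention heads.  For hypothetical sizes:
#   hidden_size = 768, num_attention_heads = 12, tensor_parallel_size = 2
#   heads per rank       = divide(12, 2)   -> 6
#   hidden size per rank = divide(768, 2)  -> 384
#   attention_head_size  = divide(768, 12) -> 64  (unchanged by the split)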
@@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from colossalai import context
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from .layers import Linear1D_Col, Linear1D_Row
from ..base_layer import ParallelLayer
from .._common_utils import to_2tuple
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP1D(ParallelLayer):
"""MLP layer for 1D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
weight_init='torch'
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
init_weight=weight_init,
init_bias=weight_init
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention1D(ParallelLayer):
"""Self-attention layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_bias = 'zero'
else:
init_bias = weight_init
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead1D(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_weight = 'zero'
init_bias = 'zero'
else:
init_weight = weight_init
init_bias = weight_init
self.linear = Linear1D_Col(
hidden_size,
num_classes,
dtype=dtype,
gather_output=True,
init_weight=init_weight,
init_bias=init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTHead(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
self.linear = nn.Linear(
hidden_size,
num_classes,
dtype=dtype
)
self._broadcast_linear_params()
def _broadcast_linear_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.linear.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.linear.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding1D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
if weight_init == 'jax':
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.proj.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.proj.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
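# --- Shape sketch (added for illustration, not part of the original file) ---
# The patch embedding is an ordinary strided convolution followed by
# flatten(2).transpose(1, 2); with a 224x224 RGB input, patch_size=16 and
# embed_dim=768 it yields 14 * 14 = 196 patch tokens of width 768.
if __name__ == '__main__':
    _proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
    _tokens = _proj(torch.randn(2, 3, 224, 224)).flatten(2).transpose(1, 2)
    assert _tokens.shape == (2, 196, 768)   # BCHW -> BNC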
@LAYERS.register_module
class ViTTokenFuser1D(ParallelLayer):
"""
Fuse the cls token and position embedding into the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.empty(
1, self.num_patches + 1, self.embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
dist.broadcast(self.pos_embed,
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
self.pos_drop = nn.Dropout(p=drop_rate)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x.contiguous()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import numbers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
import importlib
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide
from ._operation import FusedLayerNormAffineFunction1D
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
"""Linear layer with column parallelism.
@@ -44,23 +50,29 @@ class Linear1D_Col(ParallelLayer):
output_size: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False):
gather_output: bool = False,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
super().__init__()
# Keep input parameters
self.input_size = in_features
self.output_size = output_size
self.in_features = in_features
self.out_features = output_size
self.gather_output = gather_output
self.skip_bias_add = not bias
self.skip_bias_add = skip_bias_add
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.output_size_per_partition = divide(output_size, world_size)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
self.output_size_per_partition, self.in_features,
**factory_kwargs))
if bias:
@@ -72,6 +84,45 @@ class Linear1D_Col(ParallelLayer):
self.bias.zero_()
else:
self.register_parameter('bias', None)
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Set up backprop all-reduce.
@@ -104,7 +155,7 @@ class Linear1D_Row(ParallelLayer):
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param parallel_input: If set to ``False``, it's assumed that the input is splitted, defaults to False
:param parallel_input: If set to ``True``, it's assumed that the input is already split across ranks, defaults to False
:type parallel_input: bool, optional
"""
@@ -113,7 +164,10 @@ class Linear1D_Row(ParallelLayer):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = False
parallel_input: bool = False,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
super().__init__()
@@ -121,11 +175,13 @@ class Linear1D_Row(ParallelLayer):
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = not bias
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
self.input_size_per_partition = divide(in_features, world_size)
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
@@ -146,9 +202,46 @@ class Linear1D_Row(ParallelLayer):
self.bias.zero_()
else:
self.register_parameter('bias', None)
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
def reset_parameters(self) -> None:
init.xavier_normal_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
dist.broadcast(self.bias,
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
@@ -163,4 +256,29 @@ class Linear1D_Row(ParallelLayer):
if not self.skip_bias_add:
output = output + self.bias
return output
return output
else:
return output, self.bias
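# --- Note on skip_bias_add (illustrative, not part of the original file) ---
# With skip_bias_add=True the layer returns (output, bias) instead of adding
# the bias, so a caller can fuse the addition into a later kernel, as
# ViTMLP1D does with the fused bias-GELU:
#   out, bias = dense_1(x)
#   out = bias_gelu_impl(out, bias)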
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
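# --- Fallback sketch (illustrative, not part of the original file) ---
# FusedLayerNormAffineFunction1D needs the fused_mix_prec_layer_norm_cuda
# extension.  When the extension is unavailable, the same result can be
# computed with the stock functional layer norm:
#   import torch.nn.functional as F
#   output = F.layer_norm(input, self.normalized_shape, self.weight,
#                         self.bias, self.eps)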
@@ -20,7 +20,6 @@ def matmul_2d(a,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
"""Matrix multiplication for 2D parallelism
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
@@ -86,25 +85,30 @@ class Matmul_AB_2D(torch.autograd.Function):
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
A = A.reshape((-1, A_shape[-1])).contiguous()
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
B = B.reshape((-1, B_shape[-1])).contiguous()
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
op_a.wait()
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
for op in [op_a, op_b]:
op.wait()
for i in range(summa_dim):
A_temp = A.clone()
B_temp = B.clone()
src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=gpc.get_group(row_parallel_mode))
src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b,
group=gpc.get_group(col_parallel_mode))
src_a = i + summa_dim * row_rank
src_b = i + summa_dim * col_rank
src_a = src_a % summa_dim
src_b = src_b % summa_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
out = C.reshape(out_shape)
if ctx:
@@ -499,36 +503,61 @@ class _LayerNorm_2D(torch.autograd.Function):
# return input_grad, None, None, None, None, None
class _ViT_Split_Input_2D(torch.autograd.Function):
class AllGatherLast(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
batch_size: int,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
# inputs: [b, s, h/q]
# output: [b/q, s, h/q]
ctx.summa_dim = summa_dim
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
last_dim = summa_dim * inputs.size(-1)
outputs_shape = (last_dim,) + inputs.shape[:-1]
outputs = torch.empty(
outputs_shape, dtype=inputs.dtype, device=get_current_device())
dist.all_gather(
list(outputs.chunk(summa_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
return outputs
ctx.BATCH_SIZE = batch_size
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
return grad.contiguous(), None, None
class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
ctx.summa_dim = summa_dim
ctx.col_parallel_mode = col_parallel_mode
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
row_rank = gpc.get_local_rank(col_parallel_mode)
output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
output = output.clone()
return output
outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [b/q, s, h/q]
# grads: [b, s, h/q]
grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
grads = torch.empty(grads_shape,
dtype=output_grad.dtype,
device=get_current_device())
dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.col_parallel_mode))
return grads, None, None, None
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
return grad, None, None
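# --- Shape sketch (illustrative, not part of the original file) ---
# AllGatherLast gathers the hidden dimension across the column group and
# SplitFirst then splits the batch dimension, so applying the pair to an
# input of shape [b, s, h/q] on a q-way column group yields
#   AllGatherLast: [b, s, h/q] -> [b, s, h]
#   SplitFirst:    [b, s, h]   -> [b/q, s, h]
# which is the layout ViTInputSplitter2D below relies on.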
@@ -5,19 +5,21 @@ import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_Input_2D
from colossalai.core import global_context as gpc
from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
@@ -44,8 +46,8 @@ class ViTMLP2D(ParallelLayer):
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False
):
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
@@ -53,27 +55,40 @@ class ViTMLP2D(ParallelLayer):
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
)
self.act = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
@@ -117,8 +132,8 @@ class ViTSelfAttention2D(ParallelLayer):
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False
):
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
@@ -128,17 +143,24 @@ class ViTSelfAttention2D(ParallelLayer):
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
@@ -146,7 +168,7 @@ class ViTSelfAttention2D(ParallelLayer):
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
@@ -155,7 +177,7 @@ class ViTSelfAttention2D(ParallelLayer):
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
@@ -165,7 +187,7 @@ class ViTSelfAttention2D(ParallelLayer):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
@@ -199,14 +221,22 @@ class ViTHead2D(ParallelLayer):
hidden_size,
num_classes,
dtype=None,
):
weight_init='torch'):
super().__init__()
assert_summa_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D(
hidden_size,
num_classes,
dtype=dtype,
init_weight=self.init_weight, init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
@@ -236,7 +266,8 @@ class ViTPatchEmbedding2D(ParallelLayer):
patch_size,
embed_dim,
in_chans=3,
flatten=True):
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -249,39 +280,28 @@ class ViTPatchEmbedding2D(ParallelLayer):
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // self.summa_dim
self.embed_dim = embed_dim // (self.summa_dim ** 2)
with seed(ParallelMode.TENSOR):
# ensure the partitions are initialized differently
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
stride=patch_size,
device=get_current_device()
)
self._set_tensor_parallel_attribute()
# sync
self._broadcast_conv_params()
self.proj.weight.register_hook(self._sync_grad_during_backward)
self.proj.bias.register_hook(self._sync_grad_during_backward)
if weight_init == 'jax':
with seed(ParallelMode.TENSOR):
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.proj.weight)
set_tensor_parallel_attribute(self.proj.bias)
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(self.proj.weight, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
dist.broadcast(self.proj.bias, src=ranks_in_col[0],
group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
def _sync_grad_during_backward(self, grad: Tensor) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
@@ -293,6 +313,24 @@ class ViTPatchEmbedding2D(ParallelLayer):
return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = SplitFirst.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
return x
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
"""
@@ -328,64 +366,32 @@ class ViTTokenFuser2D(ParallelLayer):
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim // self.summa_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + 1, self.embed_dim // self.summa_dim))
# move to cuda before broadcast
self.to(get_current_device())
# sync param in both forward and backward
_cls_token = self.cls_token.view(-1)
_pos_embed = self.pos_embed.view(-1)
self._param = torch.cat([_cls_token, _pos_embed], dim=0)
(1, 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
self._broadcast_params(self._param)
self._param.register_hook(self._sync_grad_hook)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute(self.cls_token)
set_tensor_parallel_attribute(self.pos_embed)
def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
dist.broadcast(param, src=ranks_in_col[0],
group=col_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
grad = grad / self.summa_dim
return grad
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
cls_token = AllGatherLast.apply(
self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
pos_embed = AllGatherLast.apply(
self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = x + pos_embed
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x + self.pos_embed)
x = self.pos_drop(x)
return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
batch_size = x.size(0)
return _ViT_Split_Input_2D.apply(
x,
batch_size,
self.summa_dim,
ParallelMode.PARALLEL_2D_COL
)