"examples/vscode:/vscode.git/clone" did not exist on "04860484531519e02b96995c7811cf4101074232"
Unverified commit 0fedef4f authored by アマデウス, committed by GitHub

Layer integration (#83)



* integrated parallel layers for ease of building models

* integrated 2.5d layers

* cleaned up code and unit tests

* added a hook to log metrics by step; updated the ImageNet benchmark; fixed some bugs

* reworked initialization; cleaned up code
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent 5c3843dc
@@ -3,6 +3,7 @@
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch import Tensor
from colossalai.context import ParallelMode
@@ -10,8 +11,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode, async_op=False) -> Tensor:
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
@@ -25,29 +25,31 @@ def all_gather(tensor: Tensor, dim: int,
:rtype: :class:`torch.Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
# shape = list(temp.shape)
# shape[dim] *= depth
# out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
# out = list(torch.chunk(out, depth, dim=dim))
# out = [val.contiguous() for val in out]
shape = [1] * len(tensor.shape)
shape[dim] = depth
out = tensor.repeat(shape)
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
op = dist.all_gather(tensor_list=out,
tensor=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
# out = torch.cat(out, dim=dim)
if depth == 1:
out = [tensor]
work = None
else:
shape = list(tensor.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
temp = list(torch.chunk(out, depth, dim=0))
work = dist.all_gather(tensor_list=temp,
tensor=tensor.transpose(0, dim).contiguous(),
group=gpc.get_group(parallel_mode),
async_op=async_op)
out = torch.transpose(out, 0, dim)
if async_op:
return out, op
return out, work
else:
return out
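# Hedged local sketch of the buffer trick used in all_gather above (illustrative
# shapes, no process group needed): moving `dim` to dim 0 makes the `depth`
# chunks of the preallocated output contiguous views, so dist.all_gather can
# write straight into them and no torch.cat is needed; transposing back
# restores the original layout.
import torch

depth, dim = 4, 1
tensor = torch.randn(2, 3, 5)                      # per-rank tensor
shape = list(tensor.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth                                  # room for every rank along dim 0
out = torch.empty(shape, dtype=tensor.dtype)
chunks = list(torch.chunk(out, depth, dim=0))      # contiguous views into `out`
for slot in chunks:                                # dist.all_gather fills these in place
    slot.copy_(tensor.transpose(0, dim))
result = out.transpose(0, dim)                     # concatenation of all ranks along `dim`
assert result.shape == (2, 3 * depth, 5)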
def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode, async_op=False) -> Tensor:
def reduce_scatter(tensor: Tensor,
dim: int,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
@@ -61,52 +63,57 @@ def reduce_scatter(tensor: Tensor, dim: int,
:rtype: :class:`Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
# temp = list(torch.chunk(tensor, depth, dim=dim))
# temp = [val.contiguous() for val in temp]
# out = torch.zeros(temp[0].shape,
# dtype=temp[0].dtype,
# device=get_current_device())
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = temp[0].clone()
op = dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if depth == 1:
out = tensor
work = None
else:
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
work = dist.reduce_scatter(output=out,
input_list=temp,
op=op,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return out, op
return out, work
else:
return out
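# Hedged local sketch of the reduce_scatter semantics above (illustrative
# shapes, no process group needed): element-wise sum across ranks, after which
# rank r keeps only chunk r of the summed tensor along `dim`.
import torch

depth, dim = 4, -1
per_rank = [torch.randn(2, 8) for _ in range(depth)]
summed = torch.stack(per_rank).sum(dim=0)           # the "reduce" step (ReduceOp.SUM)
chunks = torch.chunk(summed, depth, dim=dim)        # the "scatter" step: chunk r -> rank r
assert chunks[0].shape == (2, 2)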
def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode,
async_op=False) -> Tensor:
op = dist.all_reduce(tensor,
group=gpc.get_group(parallel_mode),
async_op=async_op)
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
work = None
else:
work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, op
return tensor, work
else:
return tensor
# def scatter(tensor: Tensor, src: int, dim: int,
# parallel_mode: ParallelMode) -> Tensor:
# """Scatters in a specific dimension from source rank to all ranks in
# the parallel group.
# :param tensor: Tensor to be scattered
# :param dim: The dimension scattering in
# :param parallel_mode: Parallel group mode used in this communication
# :type tensor: Tensor
# :type dim: int
# :type parallel_mode: ParallelMode
# :return: The tensor generated by scatter
# :rtype: Tensor
# """
# depth = gpc.get_world_size(parallel_mode)
# temp = tensor.clone()
# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
# rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
# return out
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
work = None
else:
work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
else:
return tensor
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
work = None
else:
work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
else:
return tensor
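# Hedged usage sketch for the collective wrappers above: with async_op=True
# they return (tensor, work), where `work` is the torch.distributed request or
# None when the group has a single rank, so the wait must be guarded.
# ParallelMode.TENSOR and `logits` are illustrative.
out, work = all_reduce(logits, ParallelMode.TENSOR, op=ReduceOp.SUM, async_op=True)
# ... overlap independent computation here ...
if work is not None:
    work.wait()          # ensure the reduction has completed before reading `out`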
@@ -497,8 +497,7 @@ class ParallelContext:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.",
ranks=[0])
f"the default parallel seed is {ParallelMode.DATA}.")
else:
if self._verbose:
self._logger.info(
@@ -184,8 +184,6 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
def launch_from_torch(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
@@ -206,6 +204,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
launch(config=config,
local_rank=local_rank,
rank=rank,
from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *
import math
import warnings
from torch import Tensor
from torch.nn import init as init
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
import torch.nn as nn
def zeros_():
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.zeros_(tensor)
return initializer
def ones_():
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.ones_(tensor)
return initializer
def uniform_(a: float = 0., b: float = 1.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.uniform_(tensor, a, b)
return initializer
def normal_(mean: float = 0., std: float = 1.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.normal_(tensor, mean, std)
return initializer
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.trunc_normal_(tensor, mean, std, a, b)
return initializer
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
if mode == 'fan_in':
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
elif mode == 'fan_out':
assert fan_out is not None, 'Fan_out is not provided.'
fan = fan_out
else:
raise ValueError(f'Invalid initialization mode \'{mode}\'')
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
bound = math.sqrt(3.) * std
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
if mode == 'fan_in':
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
elif mode == 'fan_out':
assert fan_out is not None, 'Fan_out is not provided.'
fan = fan_out
else:
raise ValueError(f'Invalid initialization mode \'{mode}\'')
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
return nn.init.normal_(tensor, 0, std)
return initializer
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
if fan_out is not None:
fan += fan_out
std = gain * math.sqrt(scale / float(fan))
bound = a * std
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def xavier_normal_(scale: float = 2., gain: float = 1.):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
if fan_out is not None:
fan += fan_out
std = gain * math.sqrt(scale / float(fan))
return nn.init.normal_(tensor, 0., std)
return initializer
def lecun_uniform_():
# adapted from jax.nn.initializers
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
var = 1.0 / fan_in
bound = math.sqrt(3 * var)
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def lecun_normal_():
# adapted from jax.nn.initializers
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)
return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
return initializer
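# Hedged usage sketch of the initializer-factory pattern above: each factory
# returns a callable taking (tensor, fan_in, fan_out), so a layer can accept
# the scheme as an argument and apply it once the fans are known. Shapes are
# illustrative.
import math
import torch

weight = torch.empty(128, 64)
bias = torch.empty(128)
kaiming_uniform_(a=math.sqrt(5))(weight, fan_in=64, fan_out=128)
xavier_uniform_(a=1, scale=1)(bias, fan_in=64)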
from .colossalai_layer import *
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .non_parallel_layers import *
from .wrapper import *
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import collections.abc
from itertools import repeat
import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
@@ -19,8 +18,7 @@ class CheckpointModule(nn.Module):
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError(
'CheckpointModule should implement _forward method instead of origin forward')
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
def forward(self, *args, **kwargs):
if self._use_checkpoint:
@@ -36,6 +34,7 @@ class CheckpointModule(nn.Module):
self._use_checkpoint = False
return super().eval()
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
@@ -59,7 +58,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, num_partitions)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_
def _split(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, \
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = gpc.get_local_rank(parallel_mode)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# all gather
rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _ReduceGrad(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.mode = parallel_mode
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.mode), None
class _ReduceInput(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_, parallel_mode):
return _reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _gather(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.mode, ctx.dim), None, None
def reduce_grad(input_, parallel_mode):
return _ReduceGrad.apply(input_, parallel_mode)
def reduce_input(input_, parallel_mode):
return _ReduceInput.apply(input_, parallel_mode)
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
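# Hedged sketch of how the autograd pairs above are combined in a 1D
# tensor-parallel linear (assumes an initialized PARALLEL_1D group; `x` and `y`
# are illustrative activations): reduce_grad is an identity in forward but
# all-reduces gradients in backward, while gather_forward_split_backward
# concatenates shards in forward and splits the gradient in backward, keeping
# layouts consistent across autograd.
x = reduce_grad(x, ParallelMode.PARALLEL_1D)                              # identity fwd, all-reduce bwd
y = gather_forward_split_backward(y, ParallelMode.PARALLEL_1D, dim=-1)    # gather fwd, split bwd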
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from torch.nn.modules.activation import *
from torch.nn.modules.adaptive import *
from torch.nn.modules.batchnorm import *
from torch.nn.modules.channelshuffle import *
from torch.nn.modules.conv import *
from torch.nn.modules.distance import *
from torch.nn.modules.dropout import *
from torch.nn.modules.flatten import *
from torch.nn.modules.fold import *
from torch.nn.modules.instancenorm import *
from torch.nn.modules.linear import *
from torch.nn.modules.normalization import *
from torch.nn.modules.padding import *
from torch.nn.modules.pixelshuffle import *
from torch.nn.modules.pooling import *
from torch.nn.modules.rnn import *
from torch.nn.modules.sparse import *
from torch.nn.modules.transformer import *
from torch.nn.modules.upsampling import *
from .. import init as init
from .vanilla import *
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': VanillaClassifier,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_embedding = {'3d': Embedding3D}
_parallel_patchembedding = {
None: VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None,
**kwargs) -> None:
super().__init__()
if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
tensor_parallel: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)
class Classifier(nn.Module):
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
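# Hedged usage sketch of the dispatch wrappers above: the same call site can
# switch between plain torch layers and tensor-parallel ones through the
# `tensor_parallel` argument; the string keys come from the dispatch dicts at
# the top of this file, and the sizes are illustrative.
dense = Linear(1024, 4096, tensor_parallel='2.5d')                # Linear2p5D under the hood
norm = LayerNorm(1024, tensor_parallel='2.5d')                    # LayerNorm2p5D
head = Classifier(1024, num_classes=1000, tensor_parallel=None)   # VanillaClassifier fallback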
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
__all__ = [
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
from .._common_utils import to_2tuple
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normalization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module
class VanillaViTPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: size of a patch
:type patch_size: int
:param in_chans: input channels
:type in_chans: int
:param embed_dim: embedding dimension
:type embed_dim: int
:param norm_layer: layer norm class, defaults to None
:type norm_layer: Callable
:param flatten: whether to flatten the output
:type flatten: bool
:param drop: dropout rate
:type drop: float
"""
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None, flatten=True, drop=0.):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x
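# Hedged shape sketch for the patch embedding above (illustrative sizes): a
# 224x224 image with 16x16 patches gives 14 * 14 = 196 patch tokens, plus the
# prepended cls token.
embed = VanillaViTPatchEmbedding(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(2, 3, 224, 224))
assert tokens.shape == (2, 197, 768)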
@LAYERS.register_module
class VanillaViTMLP(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
:param in_features: input channels
:type in_features: int
:param hidden_features: channels of the output of the first dense layer
:type hidden_features: int
:param out_features: channels of the output of the second dense layer
:type out_features: int
:param act_layer: activation function
:type act_layer: Callable
:param drop: dropout rate
:type drop: float
"""
def __init__(self, in_features, hidden_features, out_features, act_layer=nn.GELU, drop=0.):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect implementation created for EfficientNet and similar networks; however,
the original name is misleading, as 'Drop Connect' is a different form of dropout from a separate paper.
See the discussion at https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956. We opt for
the name 'drop path' for the layer and argument rather than mixing DropConnect as a layer name with
'survival rate' as the argument name.
:param drop_prob: probability for dropout
:type drop_prob: float
:param training: whether it is training mode
:type training: bool
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
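# Hedged sketch of the scaling in drop_path above: each sample is dropped with
# probability drop_prob and, when kept, rescaled by 1/keep_prob so the expected
# output matches the identity path. Values are illustrative.
x = torch.ones(8, 4)
out = drop_path(x, drop_prob=0.25, training=True)   # each row is all zeros or all 1/0.75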
@LAYERS.register_module
class VanillaViTDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
:param drop_prob: probability for dropout
:type drop_prob: float
"""
def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@LAYERS.register_module
class VanillaViTAttention(nn.Module):
"""Vanilla attention layer of Vision Transformer
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads
:type num_heads: int, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param proj_drop: dropout probability for linear layer, defaults to 0.
:type proj_drop: float, optional
"""
def __init__(self, dim, num_heads, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@LAYERS.register_module
class VanillaViTBlock(nn.Module):
"""Vanilla Vision Transformer block
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads
:type num_heads: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
:type mlp_ratio: float, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param drop: dropout probability, defaults to 0.
:type drop: float, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param drop_path: drop path probability, defaults to 0.
:type drop_path: float, optional
:param act_layer: activation function, defaults to nn.GELU
:type act_layer: torch.nn.Module, optional
:param norm_layer: normalization layer, defaults to nn.LayerNorm
:type norm_layer: torch.nn.Module, optional
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = LAYERS.get_module('VanillaViTAttention')(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = LAYERS.get_module('VanillaViTDropPath')(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module
class VanillaViTHead(nn.Module):
"""Output layer of vanilla Vision Transformer
:param in_features: size of input tensor
:type in_features: int
:param intermediate_features: hidden size
:type intermediate_features: int
:param out_features: size of output tensor
:type out_features: int
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
in_features,
intermediate_features,
out_features,
bias=True
):
super().__init__()
self.linear_1 = nn.Linear(
in_features, intermediate_features, bias=bias)
self.act = nn.Tanh()
self.linear_2 = nn.Linear(
intermediate_features, out_features, bias=bias)
def forward(self, x):
x = x[:, 0, :].squeeze(1)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
__all__ = [
'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]
__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide, ACT2FN
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
@LAYERS.register_module
class TransformerMLP1D(ParallelLayer):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super(TransformerMLP1D, self).__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
bias=not skip_bias_add,
dtype=dtype,
gather_output = False,
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
bias=not skip_bias_add,
dtype=dtype,
parallel_input = True,
)
self.dropout = nn.Dropout(dropout_prob)
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
def forward(self, x):
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
output = self.layernorm(x + output)
return output
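# Hedged shape sketch for the 1D MLP above (illustrative numbers): with hidden
# size h = 768, mlp_ratio = 4 and a tensor-parallel size of 2, Linear1D_Col
# produces a per-rank intermediate of width 4 * 768 / 2 = 1536, and Linear1D_Row
# projects back to 768 with an all-reduce so every rank holds the full output,
# which is then passed through dropout and layer norm with the residual.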
@LAYERS.register_module
class TransformerSelfAttention1D(ParallelLayer):
"""Self attention layer for 1D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
# need to re-enable torch grad to enable fused optimization.
# self.layernorm = LayerNorm1D(
# hidden_size,
# dtype=dtype)
self.layernorm = nn.LayerNorm(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer1D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
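# Hedged usage sketch (assumes colossalai has been launched with a 1D
# tensor-parallel group): one encoder layer over a batch of hidden states with
# an additive attention mask; shapes are illustrative.
layer = TransformerLayer1D(hidden_size=768, num_attention_heads=12)
hidden_states = torch.randn(4, 128, 768, device=get_current_device())
attention_mask = torch.zeros(4, 1, 1, 128, device=get_current_device())
output = layer(hidden_states, attention_mask)        # same shape as hidden_states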
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from .._common_utils import divide
@@ -15,4 +20,128 @@ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_
def _split(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, \
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = gpc.get_local_rank(parallel_mode)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# all gather
rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _ReduceGrad(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.mode = parallel_mode
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.mode), None
class _ReduceInput(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_, parallel_mode):
return _reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _gather(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.mode, ctx.dim), None, None
def reduce_grad(input_, parallel_mode):
return _ReduceGrad.apply(input_, parallel_mode)
def reduce_input(input_, parallel_mode):
return _ReduceInput.apply(input_, parallel_mode)
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from colossalai import context
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from .layers import Linear1D_Col, Linear1D_Row
from ..base_layer import ParallelLayer
from .._common_utils import to_2tuple
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP1D(ParallelLayer):
"""MLP layer for 1D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
weight_init='torch'
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
init_weight=weight_init,
init_bias=weight_init
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention1D(ParallelLayer):
"""Self-attention layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_bias = 'zero'
else:
init_bias = weight_init
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead1D(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_weight = 'zero'
init_bias = 'zero'
else:
init_weight = weight_init
init_bias = weight_init
self.linear = Linear1D_Col(
hidden_size,
num_classes,
dtype=dtype,
gather_output=True,
init_weight=init_weight,
init_bias=init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTHead(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
self.linear = nn.Linear(
hidden_size,
num_classes,
dtype=dtype
)
self._broadcast_linear_params()
def _broadcast_linear_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.linear.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.linear.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding1D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
if weight_init == 'jax':
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.proj.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.proj.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
@LAYERS.register_module
class ViTTokenFuser1D(ParallelLayer):
"""
Fuses the cls token and position embedding into the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.empty(
1, self.num_patches + 1, self.embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
dist.broadcast(self.pos_embed,
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
self.pos_drop = nn.Dropout(p=drop_rate)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x.contiguous()
@@ -3,25 +3,24 @@
import math
import numbers
from typing import Callable, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
import importlib
from colossalai.context import seed, ParallelMode
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import FusedLayerNormAffineFunction1D
from torch import Tensor
from torch.nn.parameter import Parameter
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from ._operation import FusedLayerNormAffineFunction1D
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
@LAYERS.register_module
@@ -44,79 +43,46 @@ class Linear1D_Col(ParallelLayer):
which is :math:`Y_i = XA_i`, defaults to False
:type gather_output: bool, optional
"""
def __init__(self,
in_features: int,
output_size: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = output_size
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.in_features,
**factory_kwargs))
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
**factory_kwargs))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
@@ -133,8 +99,7 @@ class Linear1D_Col(ParallelLayer):
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
if self.skip_bias_add:
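For reference, a minimal single-device sketch (not part of this diff) of the column-parallel split: the weight is partitioned along the output dimension, each partition computes its own Y_i = X A_i, and gather_output concatenates the partial results along the last dimension.

import torch
import torch.nn.functional as F

world_size = 2                     # illustrative tensor-parallel size
x = torch.randn(4, 8)              # (batch, in_features)
full_weight = torch.randn(6, 8)    # (out_features, in_features)
shards = torch.chunk(full_weight, world_size, dim=0)       # one output-dim shard per rank

partials = [F.linear(x, w) for w in shards]                 # each rank's Y_i
gathered = torch.cat(partials, dim=-1)                      # what gather_output does
assert torch.allclose(gathered, F.linear(x, full_weight), atol=1e-6)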
......@@ -158,17 +123,15 @@ class Linear1D_Row(ParallelLayer):
:param parallel_input: If set to ``True``, the input is assumed to be already split along the last dimension, defaults to ``True``
:type parallel_input: bool, optional
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
......@@ -186,58 +149,22 @@ class Linear1D_Row(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.out_features,
self.input_size_per_partition,
**factory_kwargs))
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(
self.out_features,
**factory_kwargs
))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
dist.broadcast(self.bias,
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
......@@ -248,8 +175,7 @@ class Linear1D_Row(ParallelLayer):
if self.parallel_input:
input_ = input_
else:
input_ = split_forward_gather_backward(
input_, ParallelMode.PARALLEL_1D, dim=-1)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
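A matching single-device sketch (illustrative, not from this diff) of the row-parallel case: the weight is partitioned along the input dimension, each partition consumes the corresponding slice of the input, and the all-reduce performed by reduce_input sums the partial products.

import torch
import torch.nn.functional as F

world_size = 2                     # illustrative tensor-parallel size
x = torch.randn(4, 8)              # (batch, in_features)
full_weight = torch.randn(6, 8)    # (out_features, in_features)
x_shards = torch.chunk(x, world_size, dim=-1)               # split_forward_gather_backward
w_shards = torch.chunk(full_weight, world_size, dim=-1)     # per-rank input-dim shard

partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
reduced = sum(partials)                                      # what reduce_input (all-reduce) does
assert torch.allclose(reduced, F.linear(x, full_weight), atol=1e-5)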
......@@ -263,12 +189,13 @@ class Linear1D_Row(ParallelLayer):
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
""" Experimental
"""
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = (normalized_shape, )
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
......@@ -280,5 +207,4 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
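As a sanity reference (an assumption about the fused kernel, not part of this diff), the fused affine layer norm is expected to agree with torch.nn.functional.layer_norm up to kernel tolerance.

import torch
import torch.nn.functional as F

hidden = 16
x = torch.randn(2, 4, hidden)
weight, bias = torch.ones(hidden), torch.zeros(hidden)
ref = F.layer_norm(x, (hidden,), weight, bias, eps=1e-5)
# MixedFusedLayerNorm1D(hidden)(x) should match `ref` for freshly initialized parameters.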
from ._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D, Add_Bias_2D, matmul_2d
from ._transformer import TransformerMLP2D, TransformerSelfAttention2D, TransformerLayer2D
from ._vit import ViTMLP2D, ViTSelfAttention2D, ViTHead2D, ViTPatchEmbedding2D, ViTTokenFuser2D, ViTInputSplitter2D
from .layers import Linear2D, LayerNorm2D
from ._operation import reduce_by_batch_2d, split_batch_2d
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
__all__ = [
'Matmul_AB_2D', 'Matmul_ABT_2D', 'Matmul_ATB_2D', 'Add_Bias_2D', 'matmul_2d',
'TransformerMLP2D', 'TransformerSelfAttention2D', 'TransformerLayer2D',
'ViTMLP2D', 'ViTSelfAttention2D', 'ViTHead2D', 'ViTPatchEmbedding2D', 'ViTTokenFuser2D', 'ViTInputSplitter2D',
'Linear2D', 'LayerNorm2D'
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from .layers import Linear2D, LayerNorm2D
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2D(ParallelLayer):
"""
MLP will take the input with hidden size h, project it to a hidden dimension of
mlp_ratio * h, apply a nonlinear transformation, and project the state back to
hidden size h. Dropout is applied at the end.
:param in_features: the size of input tensor
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
:type skip_bias_add: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: float = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2D(
in_features,
int(mlp_ratio * in_features),
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
int(mlp_ratio * in_features),
in_features,
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
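A plain-PyTorch, single-device sketch of the block structure above (illustrative names, not from this diff): expand to mlp_ratio * h, apply the activation, project back, then dropout and a post-LN residual.

import torch
from torch import nn
import torch.nn.functional as F

h, mlp_ratio = 64, 4.0
dense_1 = nn.Linear(h, int(mlp_ratio * h))
dense_2 = nn.Linear(int(mlp_ratio * h), h)
dropout, layernorm = nn.Dropout(0.1), nn.LayerNorm(h)

x = torch.randn(2, 10, h)
out = dense_2(F.gelu(dense_1(x)))
out = layernorm(x + dropout(out))   # same dropout -> residual -> post-LN ordering as forward()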
@LAYERS.register_module
class TransformerSelfAttention2D(ParallelLayer):
"""Self attention layer for 2D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.layernorm = LayerNorm2D(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
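A shape walk-through of the fused-QKV attention path above, on a single device; `heads` stands in for the per-partition head count (num_attention_heads divided by summa_dim) and all sizes are illustrative.

import math
import torch

batch, seq, heads, head_size = 2, 10, 4, 16
qkv = torch.randn(batch, seq, 3 * heads * head_size)       # output of query_key_value
qkv = qkv.view(batch, seq, heads, 3 * head_size).permute(0, 2, 1, 3)
q, k, v = torch.chunk(qkv, 3, dim=-1)                       # each (batch, heads, seq, head_size)
scores = q @ k.transpose(-1, -2) / math.sqrt(head_size)     # (batch, heads, seq, seq)
probs = torch.softmax(scores, dim=-1)
context = (probs @ v).permute(0, 2, 1, 3).reshape(batch, seq, heads * head_size)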
@LAYERS.register_module
class TransformerLayer2D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention2D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP2D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
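A hedged usage sketch with the constructor signature shown above; it assumes a 2D (SUMMA) tensor-parallel context has already been initialized through colossalai's launcher, otherwise assert_summa_initialization will raise. Input partitioning is handled by the underlying Linear2D layers.

layer = TransformerLayer2D(hidden_size=768,
                           num_attention_heads=12,
                           act_func='gelu',
                           mlp_ratio=4.0,
                           attention_dropout_prob=0.1,
                           hidden_dropout_prob=0.1)
# hidden_states: (batch, seq_len, hidden); attention_mask: additive mask broadcastable
# to the attention scores, as in TransformerSelfAttention2D.forward above.
# output = layer(hidden_states, attention_mask)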
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP2D(ParallelLayer):
"""MLP layer for 2D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
)
# Project back to h.
self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
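For the checkpoint=True path, a single-device reference using torch.utils.checkpoint as a stand-in for colossalai.utils.checkpoint (an illustrative assumption): the wrapped block's activations are recomputed during backward to trade compute for memory.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 10, 64, requires_grad=True)
out = torch_checkpoint(mlp, x)   # forward pass without storing intermediate activations
out.sum().backward()             # activations are rebuilt during this backward pass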
@LAYERS.register_module
class ViTSelfAttention2D(ParallelLayer):
"""Self-attention layer for 2D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
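The per-partition sizes computed above, worked through with illustrative numbers for a summa_dim = 2 process grid:

hidden_size = 768
num_attention_heads = 12
summa_dim = 2

heads_per_partition = num_attention_heads // summa_dim      # divide(num_heads, summa_dim)   -> 6
attention_head_size = hidden_size // num_attention_heads    # divide(hidden_size, num_heads) -> 64
all_head_size = heads_per_partition * attention_head_size   # per-partition width            -> 384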
@LAYERS.register_module
class ViTHead2D(ParallelLayer):
"""Output layer for 2D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D(
hidden_size,
num_classes,
dtype=dtype,
init_weight=self.init_weight, init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
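A single-device reference for the head (illustrative, with nn.Linear standing in for Linear2D): take the CLS token at position 0 and project it to the class logits.

import torch
from torch import nn

num_classes, hidden = 10, 64
head = nn.Linear(hidden, num_classes)
tokens = torch.randn(2, 197, hidden)   # (batch, 1 + num_patches, hidden)
logits = head(tokens[:, 0])            # (batch, num_classes)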
@LAYERS.register_module
class ViTPatchEmbedding2D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // (self.summa_dim ** 2)
with seed(ParallelMode.TENSOR):
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size,
device=get_current_device()
)
self._set_tensor_parallel_attribute()
if weight_init == 'jax':
with seed(ParallelMode.TENSOR):
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
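A plain-PyTorch shape check for the patch projection above; with a summa_dim = 2 grid each rank holds embed_dim // summa_dim**2 output channels, and all numbers are illustrative.

import torch
from torch import nn

img_size, patch_size, embed_dim, summa_dim = 224, 16, 768, 2
grid = img_size // patch_size                 # 14
num_patches = grid * grid                     # 196
local_dim = embed_dim // summa_dim ** 2       # 192 channels held by this rank

proj = nn.Conv2d(3, local_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, 3, img_size, img_size)
patches = proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC, i.e. (2, 196, 192)
assert patches.shape == (2, num_patches, local_dim)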
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = SplitFirst.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
return x
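A single-process stand-in for the shape effect of the two ops above, assuming (as the names suggest) that AllGatherLast concatenates along the last dimension across the column group and SplitFirst chunks along the first; this illustrates shapes only, not the collectives themselves.

import torch

summa_dim = 2
local = torch.randn(4, 196, 192)                          # per-rank (batch, tokens, local_dim)
gathered = torch.cat([local] * summa_dim, dim=-1)         # gather along last dim  -> (4, 196, 384)
this_rank = torch.chunk(gathered, summa_dim, dim=0)[0]    # split along first dim  -> (2, 196, 384)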
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
(1, 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = AllGatherLast.apply(
self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
pos_embed = AllGatherLast.apply(
self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = x + pos_embed
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x)
return x
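A single-device shape sketch of the fusion above with illustrative sizes: prepend the (gathered) CLS token, add the positional embedding, then apply dropout.

import torch

batch, num_patches, local_dim = 2, 196, 192
x = torch.randn(batch, num_patches, local_dim)
cls_token = torch.zeros(1, 1, local_dim).expand(batch, -1, -1)
pos_embed = torch.randn(1, num_patches + 1, local_dim)
x = torch.cat((cls_token, x), dim=1) + pos_embed          # (2, 197, 192), then pos_drop(x)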